juliuse committed
Commit 90a9dd3 · 1 Parent(s): 0c72eec

Initial commit: track binaries with LFS

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. .gitignore +177 -0
  3. DIV2k_mask.npy +3 -0
  4. LCFM/SD3.5-large_MSE_DIV2k.npy +3 -0
  5. README.md +70 -14
  6. app.py +774 -0
  7. assets/teaser3.svg +0 -0
  8. configs/inpainting.yaml +45 -0
  9. configs/inpainting_gradio.yaml +45 -0
  10. configs/motion_deblur.yaml +43 -0
  11. configs/x12.yaml +48 -0
  12. configs/x12_gradio.yaml +48 -0
  13. demo_images/demo_0_image.png +3 -0
  14. demo_images/demo_0_meta.json +7 -0
  15. demo_images/demo_1_image.png +3 -0
  16. demo_images/demo_1_mask.png +3 -0
  17. demo_images/demo_1_meta.json +7 -0
  18. demo_images/demo_2_image.png +3 -0
  19. demo_images/demo_2_mask.png +3 -0
  20. demo_images/demo_2_meta.json +7 -0
  21. demo_images/demo_3_image.png +3 -0
  22. demo_images/demo_3_mask.png +3 -0
  23. demo_images/demo_3_meta.json +7 -0
  24. examples/girl.png +3 -0
  25. examples/sunflowers.png +3 -0
  26. inference_scripts/run_image_inv.py +159 -0
  27. requirements.txt +18 -0
  28. scripts/compute_metrics.py +144 -0
  29. scripts/generate_caption.py +107 -0
  30. setup.py +14 -0
  31. src/flair/__init__.py +0 -0
  32. src/flair/degradations.py +198 -0
  33. src/flair/functions/__init__.py +0 -0
  34. src/flair/functions/ckpt_util.py +72 -0
  35. src/flair/functions/conjugate_gradient.py +66 -0
  36. src/flair/functions/degradation.py +211 -0
  37. src/flair/functions/jpeg.py +392 -0
  38. src/flair/functions/measurements.py +429 -0
  39. src/flair/functions/nonuniform/kernels/000001.npy +3 -0
  40. src/flair/functions/svd_ddnm.py +206 -0
  41. src/flair/functions/svd_operators.py +1308 -0
  42. src/flair/helper_functions.py +31 -0
  43. src/flair/pipelines/__init__.py +0 -0
  44. src/flair/pipelines/model_loader.py +97 -0
  45. src/flair/pipelines/sd3.py +111 -0
  46. src/flair/pipelines/utils.py +114 -0
  47. src/flair/scheduling.py +882 -0
  48. src/flair/utils.py +406 -0
  49. src/flair/utils/__init__.py +0 -0
  50. src/flair/utils/blur_util.py +48 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
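For context, the rule added above is the one `git lfs track` writes. A minimal sketch of how PNGs are typically moved under LFS with standard Git LFS commands (the exact workflow used for this commit is not shown in the diff and is an assumption):

```bash
git lfs install            # one-time setup of the LFS filters
git lfs track "*.png"      # appends the *.png rule to .gitattributes
git add .gitattributes demo_images/ examples/
git commit -m "Initial commit: track binaries with LFS"
```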
.gitignore ADDED
@@ -0,0 +1,177 @@
1
+ SD3_*
2
+ *.pdf
3
+ sd3_*
4
+ wandb/*
5
+ output/*
6
+ data/*
7
+ vscode/*
8
+ # Byte-compiled / optimized / DLL files
9
+ __pycache__/
10
+ *.py[cod]
11
+ *$py.class
12
+ media/*
13
+ pose/*
14
+ wandb/*
15
+ thirdparty/*
16
+ # C extensions
17
+ *.so
18
+
19
+ # Distribution / packaging
20
+ .Python
21
+ build/
22
+ develop-eggs/
23
+ dist/
24
+ downloads/
25
+ eggs/
26
+ .eggs/
27
+ lib/
28
+ lib64/
29
+ parts/
30
+ sdist/
31
+ var/
32
+ wheels/
33
+ share/python-wheels/
34
+ *.egg-info/
35
+ .installed.cfg
36
+ *.egg
37
+ MANIFEST
38
+
39
+ # PyInstaller
40
+ # Usually these files are written by a python script from a template
41
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
42
+ *.manifest
43
+ *.spec
44
+
45
+ # Installer logs
46
+ pip-log.txt
47
+ pip-delete-this-directory.txt
48
+
49
+ # Unit test / coverage reports
50
+ htmlcov/
51
+ .tox/
52
+ .nox/
53
+ .coverage
54
+ .coverage.*
55
+ .cache
56
+ nosetests.xml
57
+ coverage.xml
58
+ *.cover
59
+ *.py,cover
60
+ .hypothesis/
61
+ .pytest_cache/
62
+ cover/
63
+
64
+ # Translations
65
+ *.mo
66
+ *.pot
67
+
68
+ # Django stuff:
69
+ *.log
70
+ local_settings.py
71
+ db.sqlite3
72
+ db.sqlite3-journal
73
+
74
+ # Flask stuff:
75
+ instance/
76
+ .webassets-cache
77
+
78
+ # Scrapy stuff:
79
+ .scrapy
80
+
81
+ # Sphinx documentation
82
+ docs/_build/
83
+
84
+ # PyBuilder
85
+ .pybuilder/
86
+ target/
87
+
88
+ # Jupyter Notebook
89
+ .ipynb_checkpoints
90
+
91
+ # IPython
92
+ profile_default/
93
+ ipython_config.py
94
+
95
+ # pyenv
96
+ # For a library or package, you might want to ignore these files since the code is
97
+ # intended to run in multiple environments; otherwise, check them in:
98
+ # .python-version
99
+
100
+ # pipenv
101
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
102
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
103
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
104
+ # install all needed dependencies.
105
+ #Pipfile.lock
106
+
107
+ # poetry
108
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
109
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
110
+ # commonly ignored for libraries.
111
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
112
+ #poetry.lock
113
+
114
+ # pdm
115
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
116
+ #pdm.lock
117
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
118
+ # in version control.
119
+ # https://pdm.fming.dev/#use-with-ide
120
+ .pdm.toml
121
+
122
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
123
+ __pypackages__/
124
+
125
+ # Celery stuff
126
+ celerybeat-schedule
127
+ celerybeat.pid
128
+
129
+ # SageMath parsed files
130
+ *.sage.py
131
+
132
+ # Environments
133
+ .env
134
+ .venv
135
+ env/
136
+ venv/
137
+ ENV/
138
+ env.bak/
139
+ venv.bak/
140
+
141
+ # Spyder project settings
142
+ .spyderproject
143
+ .spyproject
144
+
145
+ # Rope project settings
146
+ .ropeproject
147
+
148
+ # mkdocs documentation
149
+ /site
150
+
151
+ # mypy
152
+ .mypy_cache/
153
+ .dmypy.json
154
+ dmypy.json
155
+
156
+ # Pyre type checker
157
+ .pyre/
158
+
159
+ # pytype static type analyzer
160
+ .pytype/
161
+
162
+ # Cython debug symbols
163
+ cython_debug/
164
+
165
+ # PyCharm
166
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
167
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
168
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
169
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
170
+ #.idea/
171
+ slurm*.out
172
+ prompt_cache/*
173
+ .vscode/
174
+ cache/
175
+ notebooks/
176
+ paper_vis_scripts/
177
+ sweeps/
DIV2k_mask.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:774e6caa5f557de4bd3bd7e0df1de2b770974f96d703a320941ab7fc3cf66ad1
3
+ size 589952
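The mask itself is stored as an LFS object; below is a quick sanity check after fetching it, a sketch assuming the file holds a boolean array at the 768×768 working resolution (which the 589,952-byte size suggests: 768·768 bytes of data plus the .npy header).

```python
import numpy as np

mask = np.load("DIV2k_mask.npy")
print(mask.shape, mask.dtype)   # expected (assumption): (768, 768), bool
# True pixels are observed, False pixels are masked out (see README)
print(f"observed: {mask.sum()} / {mask.size} pixels")
```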
LCFM/SD3.5-large_MSE_DIV2k.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4ee67869328e73b1989ec52556c21f2aed2c6c6f9e90ef33e047b2514d22d31
3
+ size 528
README.md CHANGED
@@ -1,14 +1,70 @@
1
- ---
2
- title: FLAIR
3
- emoji: 😻
4
- colorFrom: yellow
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.32.1
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Training-Free Super Resolution and Inpainting with FLAIR
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ <div align="center">
2
+
3
+ # FLAIR: Flow-Based Latent Alignment for Image Restoration
4
+
5
+ **Julius Erbach<sup>1</sup>, Dominik Narnhofer<sup>1</sup>, Andreas Dombos<sup>1</sup>, Jan Eric Lenssen<sup>1</sup>, Bernt Schiele<sup>2</sup>, Konrad Schindler<sup>1</sup>**
6
+ <br>
7
+ <sup>1</sup> Photogrammetry and Remote Sensing, ETH Zurich
8
+ <sup>2</sup> Max Planck Institute for Informatics, Saarbrücken
9
+
10
+ [![Paper](https://img.shields.io/badge/arXiv-PDF-b31b1b)](link)
11
+ [![Page](https://img.shields.io/badge/Project-Page-green)](https://inverseFLAIR.github.io)
12
+ [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](link)
13
+ </div>
14
+
15
+ <p align="center">
16
+ <img src="assets/teaser3.svg" alt="teaser" width="98%"/>
17
+ </p>
18
+ <p align="center">
19
+ <em>FLAIR</em> is a novel approach for solving inverse imaging problems using flow-based posterior sampling.
20
+ </p>
21
+
22
+ ## Installation
23
+
24
+ 1. Clone the repository:
25
+ ```bash
26
+ git clone <your-repo-url>
27
+ cd <your-repo-name>
28
+ ```
29
+
30
+ 2. Create a virtual environment (recommended):
31
+ ```bash
32
+ python3 -m venv venv
33
+ source venv/bin/activate
34
+ ```
35
+
36
+ 3. Install the required dependencies from `requirements.txt`:
37
+ ```bash
38
+ pip install -r requirements.txt
39
+ pip install .
40
+ ```
41
+
42
+ ## Running Inference
43
+
44
+ To run inference, use the Python script `run_image_inv.py` together with the corresponding config file.
45
+ An example from the FFHQ dataset:
46
+ ```bash
47
+ python inference_scripts/run_image_inv.py --config configs/inpainting.yaml --target_file examples/girl.png --result_folder output --prompt="a high quality photo of a face"
48
+ ```
49
+ Or an example from the DIV2K dataset, with captions generated by DAPE from the degraded input. Masks can be defined as rectangle coordinates in the config file or provided as a .npy file in which true pixels are observed and false pixels are masked out (a sketch for creating such a mask follows the command below).
50
+
51
+ ```bash
52
+ python inference_scripts/run_image_inv.py --config configs/inpainting.yaml --target_file examples/sunflowers.png --result_folder output --prompt="a high quality photo of bloom, blue, field, flower, sky, sunflower, sunflower field, yellow" --mask_file DIV2k_mask.npy
53
+ ```
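As referenced above, a minimal sketch of how such a mask file could be created with NumPy (768×768 matches the configured resolution; `my_mask.npy` and the rectangle below are hypothetical):

```python
import numpy as np

# Boolean mask at the working resolution: True = observed, False = masked out.
mask = np.ones((768, 768), dtype=bool)
mask[128:384, 384:640] = False   # example: hide a rectangular region

np.save("my_mask.npy", mask)     # then pass --mask_file my_mask.npy
```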
54
+
55
+ ```bash
56
+ python inference_scripts/run_image_inv.py --config configs/x12.yaml --target_file examples/sunflowers.png --result_folder output --prompt="a high quality photo of bloom, blue, field, flower, sky, sunflower, sunflower field, yellow"
57
+ ```
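Note that `run_image_inv.py` expects either a single `prompt` or a `caption_file` in the YAML config, but not both. A hedged sketch of the two options (`captions.json` is a hypothetical file name; the caption-file format is whatever `data_utils.load_captions_from_file` expects):

```yaml
# Option A: one prompt applied to every image (as in the shipped configs)
prompt: a high quality photo of

# Option B: per-image captions instead of a single prompt (mutually exclusive with `prompt`)
# caption_file: captions.json
```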
58
+
59
+ ## Citation
60
+
61
+ If you find this work useful in your research, please cite our paper:
62
+
63
+ ```bibtex
64
+ @article{er2025solving,
65
+ title={Solving Inverse Problems with FLAIR},
66
+ author={Erbach, Julius and Narnhofer, Dominik and Dombos, Andreas and Lenssen, Jan Eric and Schiele, Bernt and Schindler, Konrad},
67
+ journal={arXiv},
68
+ year={2025}
69
+ }
70
+ ```
app.py ADDED
@@ -0,0 +1,774 @@
1
+ import gradio as gr
2
+ from gradio_imageslider import ImageSlider # Replaces gr.ImageCompare
3
+ import torch
4
+ import yaml
5
+ import numpy as np
6
+ from PIL import Image
7
+ import torchvision.transforms.functional as TF
8
+ import random
9
+ import os
10
+ import sys
11
+ import json # Added import
12
+ import copy
13
+ # Add project root to sys.path to allow direct import of var_post_samp
14
+ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "."))
15
+ if project_root not in sys.path:
16
+ sys.path.insert(0, project_root)
17
+
18
+ from flair.pipelines import model_loader
19
+ from flair import var_post_samp, degradations
20
+
21
+
22
+
23
+ CONFIG_FILE_PATH = "./configs/inpainting_gradio.yaml"
24
+ DTYPE = torch.bfloat16
25
+
26
+ # Global variables to hold the model and config
27
+ MODEL = None
28
+ POSTERIOR_MODEL = None
29
+ BASE_CONFIG = None
30
+ DEVICES = None
31
+ PRIMARY_DEVICE = None
32
+ # project_root is already defined globally, will be used by save_configuration
33
+
34
+ SR_CONFIG_FILE_PATH = "./configs/x12_gradio.yaml"
35
+
36
+ # Function to save the current configuration for demo examples
37
+ def save_configuration(image_editor_data, image_input, prompt, seed_val, task, random_seed_bool, steps_val):
38
+ global project_root # Ensure access to the globally defined project_root
39
+ if task == "Super Resolution":
40
+ if image_input is None:
41
+ return gr.Markdown("""<p style='color:red;'>Error: No low-resolution image loaded.</p>""")
42
+ # For Super Resolution, we don't need a mask, just the image
43
+ input_image = image_input
44
+ mask_image = None
45
+ else: # Inpainting task
46
+ if image_editor_data is None or image_editor_data['background'] is None:
47
+ return gr.Markdown("""<p style='color:red;'>Error: No background image loaded.</p>""")
48
+
49
+ # Check if layers exist and the first layer (mask) is not None
50
+ if not image_editor_data['layers'] or image_editor_data['layers'][0] is None:
51
+ return gr.Markdown("""<p style='color:red;'>Error: No mask drawn. Please use the brush tool to draw a mask.</p>""")
52
+
53
+ input_image = image_editor_data['background']
54
+ mask_image = image_editor_data['layers'][0]
55
+
56
+ metadata = {
57
+ "prompt": prompt,
58
+ "seed_on_slider": int(seed_val),
59
+ "use_random_seed_checkbox": bool(random_seed_bool),
60
+ "num_steps": int(steps_val),
61
+ "task_type": task # Always inpainting for now
62
+ }
63
+
64
+ demo_images_dir = os.path.join(project_root, "demo_images")
65
+ try:
66
+ os.makedirs(demo_images_dir, exist_ok=True)
67
+ except Exception as e:
68
+ return gr.Markdown(f"""<p style='color:red;'>Error creating directory {demo_images_dir}: {str(e)}</p>""")
69
+
70
+ i = 0
71
+ while True:
72
+ base_filename = f"demo_{i}"
73
+ meta_check_path = os.path.join(demo_images_dir, f"{base_filename}_meta.json")
74
+ if not os.path.exists(meta_check_path):
75
+ break
76
+ i += 1
77
+
78
+ image_save_path = os.path.join(demo_images_dir, f"{base_filename}_image.png")
79
+ mask_save_path = os.path.join(demo_images_dir, f"{base_filename}_mask.png")
80
+ meta_save_path = os.path.join(demo_images_dir, f"{base_filename}_meta.json")
81
+
82
+ try:
83
+ input_image.save(image_save_path)
84
+ if mask_image is not None:
85
+ # Ensure mask is saved in a usable format, e.g., 'L' mode for grayscale, or 'RGBA' if it has transparency
86
+ if mask_image.mode != 'L' and mask_image.mode != '1': # If not already grayscale or binary
87
+ mask_image = mask_image.convert('RGBA') # Preserve transparency if drawn, or convert to L
88
+ mask_image.save(mask_save_path)
89
+
90
+ with open(meta_save_path, 'w') as f:
91
+ json.dump(metadata, f, indent=4)
92
+ return gr.Markdown(f"""<p style='color:green;'>Configuration saved as {base_filename} in demo_images folder.</p>""")
93
+ except Exception as e:
94
+ return gr.Markdown(f"""<p style='color:red;'>Error saving configuration: {str(e)}</p>""")
95
+
96
+ def embed_prompt(prompt, device):
97
+ print(f"Generating prompt embeddings for: {prompt}")
98
+ with torch.no_grad(): # Add torch.no_grad() here
99
+ POSTERIOR_MODEL.model.text_encoder.to(device).to(torch.bfloat16)
100
+ POSTERIOR_MODEL.model.text_encoder_2.to(device).to(torch.bfloat16)
101
+ POSTERIOR_MODEL.model.text_encoder_3.to(device).to(torch.bfloat16)
102
+ (
103
+ prompt_embeds,
104
+ negative_prompt_embeds,
105
+ pooled_prompt_embeds,
106
+ negative_pooled_prompt_embeds,
107
+ ) = POSTERIOR_MODEL.model.encode_prompt(
108
+ prompt=prompt,
109
+ prompt_2=prompt,
110
+ prompt_3=prompt,
111
+ negative_prompt="",
112
+ negative_prompt_2="",
113
+ negative_prompt_3="",
114
+ do_classifier_free_guidance=POSTERIOR_MODEL.model.do_classifier_free_guidance,
115
+ prompt_embeds=None,
116
+ negative_prompt_embeds=None,
117
+ pooled_prompt_embeds=None,
118
+ negative_pooled_prompt_embeds=None,
119
+ device=device,
120
+ clip_skip=None,
121
+ num_images_per_prompt=1,
122
+ max_sequence_length=256,
123
+ lora_scale=None,
124
+ )
125
+ # POSTERIOR_MODEL.model.text_encoder.to("cpu").to(torch.bfloat16)
126
+ # POSTERIOR_MODEL.model.text_encoder_2.to("cpu").to(torch.bfloat16)
127
+ # POSTERIOR_MODEL.model.text_encoder_3.to("cpu").to(torch.bfloat16)
128
+ torch.cuda.empty_cache() # Clear GPU memory after embedding generation
129
+ return {
130
+ "prompt_embeds": prompt_embeds.to(device, dtype=DTYPE),
131
+ "negative_prompt_embeds": negative_prompt_embeds.to(device, dtype=DTYPE) if negative_prompt_embeds is not None else None,
132
+ "pooled_prompt_embeds": pooled_prompt_embeds.to(device, dtype=DTYPE),
133
+ "negative_pooled_prompt_embeds": negative_pooled_prompt_embeds.to(device, dtype=DTYPE) if negative_pooled_prompt_embeds is not None else None
134
+ }
135
+
136
+ def initialize_globals():
137
+ global MODEL, POSTERIOR_MODEL, BASE_CONFIG, DEVICES, PRIMARY_DEVICE
138
+
139
+ print("Global initialization started...")
140
+ # Setup device (run once)
141
+ if torch.cuda.is_available():
142
+ num_gpus = torch.cuda.device_count()
143
+ DEVICES = [f"cuda:{i}" for i in range(num_gpus)]
144
+ PRIMARY_DEVICE = DEVICES[0]
145
+ print(f"Initializing with devices: {DEVICES}, Primary: {PRIMARY_DEVICE}")
146
+ else:
147
+ DEVICES = ["cpu"]
148
+ PRIMARY_DEVICE = "cpu"
149
+ print("No CUDA devices found. Initializing with CPU.")
150
+
151
+ # Load base configuration (once)
152
+ with open(CONFIG_FILE_PATH, "r") as f:
153
+ BASE_CONFIG = yaml.safe_load(f)
154
+
155
+ # Prepare a temporary config for the initial model and posterior_model loading
156
+ init_config = BASE_CONFIG.copy()
157
+
158
+ # Ensure prompt/caption settings are valid for model_loader for initialization
159
+ # Forcing prompt mode for initial load.
160
+ init_config["prompt"] = [BASE_CONFIG.get("prompt", "Initialization prompt")]
161
+ init_config["caption_file"] = None
162
+
163
+ # Default values that might be needed by model_loader or utils called within
164
+ init_config.setdefault("target_file", "dummy_target.png")
165
+ init_config.setdefault("result_file", "dummy_results/")
166
+ init_config.setdefault("seed", random.randint(0, 2**32 - 1)) # Init with a random seed
167
+
168
+ print("Loading base model and variational posterior model once...")
169
+ # MODEL is the main diffusion model, loaded once.
170
+ # inp_kwargs_for_init are based on init_config, not directly used for subsequent inferences.
171
+ model_obj, _ = model_loader.load_model(init_config, device=DEVICES)
172
+ MODEL = model_obj
173
+
174
+ # Initialize VariationalPosterior once with the loaded MODEL and init_config.
175
+ # Its internal forward_operator will be based on init_config's degradation settings,
176
+ # but will be replaced in each inpaint_image call.
177
+ POSTERIOR_MODEL = var_post_samp.VariationalPosterior(MODEL, init_config)
178
+ print("Global initialization complete.")
179
+
180
+
181
+ def load_config_for_inference(prompt_text, seed=None):
182
+ # This function is now for creating a temporary config for each inference call,
183
+ # primarily to get up-to-date inp_kwargs via model_loader.
184
+ # It starts from BASE_CONFIG and applies current overrides.
185
+ if BASE_CONFIG is None:
186
+ raise RuntimeError("Base config not initialized. Call initialize_globals().")
187
+
188
+ current_config = BASE_CONFIG.copy()
189
+
190
+ current_config["prompt"] = [prompt_text] # Override with user's prompt
191
+ current_config["caption_file"] = None # Ensure we are in prompt mode
192
+
193
+ if seed is None:
194
+ seed = current_config.get("seed", random.randint(0, 2**32 - 1))
195
+ current_config["seed"] = seed
196
+ # Set global seeds for reproducibility for the current call
197
+ torch.manual_seed(seed)
198
+ np.random.seed(seed)
199
+ random.seed(seed)
200
+ print(f"Using seed for current inference: {seed}")
201
+
202
+ # Ensure other necessary fields are in 'current_config' if model_loader needs them
203
+ current_config.setdefault("target_file", "dummy_target.png")
204
+ current_config.setdefault("result_file", "dummy_results/")
205
+
206
+ return current_config
207
+
208
+ def preprocess_image(pil_image, resolution, is_mask=False):
209
+ img = pil_image.convert("RGB") if not is_mask else pil_image.convert("L")
210
+
211
+ # Calculate new dimensions to maintain aspect ratio, making shorter edge 'resolution'
212
+ original_width, original_height = img.size
213
+ if original_width < original_height:
214
+ new_short_edge = resolution
215
+ new_long_edge = int(resolution * (original_height / original_width))
216
+ new_width = new_short_edge
217
+ new_height = new_long_edge
218
+ else:
219
+ new_short_edge = resolution
220
+ new_long_edge = int(resolution * (original_width / original_height))
221
+ new_height = new_short_edge
222
+ new_width = new_long_edge
223
+
224
+ # TF.resize expects [height, width]
225
+ img = TF.resize(img, [new_height, new_width], interpolation=TF.InterpolationMode.LANCZOS)
226
+
227
+ # Center crop to the target square resolution
228
+ img = TF.center_crop(img, [resolution, resolution])
229
+
230
+ img_tensor = TF.to_tensor(img) # Scales to [0, 1]
231
+ if is_mask:
232
+ # Ensure mask is binary (0 or 1), 1 for region to inpaint
233
+ # The mask from ImageEditor is RGBA, convert to L first.
234
+ img = img.convert('L')
235
+ img_tensor = TF.to_tensor(img) # Recalculate tensor after convert
236
+ img_tensor = (img_tensor == 0.) # Threshold for mask (drawn parts are usually non-black)
237
+ img_tensor = img_tensor.repeat(3, 1, 1) # Repeat mask across 3 channels
238
+ else:
239
+ # Normalize image to [-1, 1]
240
+ img_tensor = img_tensor * 2 - 1
241
+ return img_tensor.unsqueeze(0) # Add batch dimension
242
+
243
+ def preprocess_lr_image(pil_image, resolution, device, dtype):
244
+ if pil_image is None:
245
+ raise ValueError("Input PIL image cannot be None.")
246
+ img = pil_image.convert("RGB")
247
+
248
+ # Center crop to the target square resolution (no resizing)
249
+ img = TF.center_crop(img, [resolution, resolution])
250
+
251
+ img_tensor = TF.to_tensor(img) # Scales to [0, 1]
252
+ # Normalize image to [-1, 1]
253
+ img_tensor = img_tensor * 2 - 1
254
+ return img_tensor.unsqueeze(0).to(device, dtype=dtype) # Add batch dimension and move to device
255
+
256
+
257
+ def postprocess_image(image_tensor):
258
+ # Remove batch dimension, move to CPU, convert to float
259
+ image_tensor = image_tensor.squeeze(0).cpu().float()
260
+ # Denormalize from [-1, 1] to [0, 1]
261
+ image_tensor = image_tensor * 0.5 + 0.5
262
+ # Clip values to [0, 1]
263
+ image_tensor = torch.clamp(image_tensor, 0, 1)
264
+ # Convert to PIL Image
265
+ pil_image = TF.to_pil_image(image_tensor)
266
+ return pil_image
267
+
268
+ def inpaint_image(image_editor_output, prompt_text, fixed_seed_value, use_random_seed, guidance_scale, num_steps): # MODIFIED: seed_input changed to fixed_seed_value, use_random_seed
269
+ try:
270
+ if image_editor_output is None:
271
+ raise gr.Error("Please upload an image and draw a mask.")
272
+
273
+ input_pil = image_editor_output['background']
274
+
275
+ if not image_editor_output['layers'] or image_editor_output['layers'][0] is None:
276
+ raise gr.Error("Please draw a mask on the image using the brush tool.")
277
+ mask_pil = image_editor_output['layers'][0]
278
+
279
+
280
+ if input_pil is None:
281
+ raise gr.Error("Please upload an image.")
282
+ if mask_pil is None:
283
+ raise gr.Error("Please draw a mask on the image.")
284
+
285
+ current_seed = None
286
+ if use_random_seed:
287
+ current_seed = None # load_config_for_inference will generate a random seed
288
+ else:
289
+ try:
290
+ current_seed = int(fixed_seed_value)
291
+ except ValueError:
292
+ # This should ideally not happen with a slider, but good for robustness
293
+ raise gr.Error("Seed must be an integer.")
294
+
295
+ # Prepare config for current inference (gets prompt, seed)
296
+ current_config = load_config_for_inference(prompt_text, current_seed)
297
+ resolution = current_config["resolution"]
298
+
299
+ # MODIFIED: Set num_steps from slider into the current_config
300
+ # Assuming 'num_steps' is a key POSTERIOR_MODEL will use from its config.
301
+ # Common alternatives could be current_config['solver_kwargs']['n_steps'] = num_steps
302
+ current_config['n_steps'] = int(num_steps)
303
+ print(f"Using num_steps: {current_config['n_steps']}")
304
+
305
+
306
+ # Preprocess image and mask
307
+ guidance_img_tensor = preprocess_image(input_pil, resolution, is_mask=False).to(PRIMARY_DEVICE, dtype=DTYPE)
308
+ # Mask from ImageEditor is RGBA, preprocess_image will handle conversion to L and then binary
309
+ mask_tensor = preprocess_image(mask_pil, resolution, is_mask=True).to(PRIMARY_DEVICE, dtype=DTYPE)
310
+
311
+ # Get inp_kwargs for the CURRENT prompt and config.
312
+ print("Preparing inference inputs (e.g., prompt embeddings)...")
313
+ prompt_embeds = embed_prompt(prompt_text, device=PRIMARY_DEVICE) # Embed the prompt for the current inference
314
+ current_inp_kwargs = prompt_embeds
315
+ # MODIFIED: Use guidance_scale from slider
316
+ current_inp_kwargs['guidance'] = float(guidance_scale)
317
+ print(f"Using guidance_scale: {current_inp_kwargs['guidance']}")
318
+
319
+ # Update the global POSTERIOR_MODEL's config for this call.
320
+ # This ensures its methods use the latest settings (like num_steps) if they access self.config.
321
+ POSTERIOR_MODEL.config = current_config
322
+ POSTERIOR_MODEL.model._guidance_scale = guidance_scale
323
+ print("Applying forward operator (masking)...")
324
+ # Directly set the forward_operator on the global POSTERIOR_MODEL instance
325
+ # H and W are height and width of the guidance image tensor
326
+ POSTERIOR_MODEL.forward_operator = degradations.Inpainting(
327
+ mask=mask_tensor.bool()[0], # Inpainting often expects a boolean mask
328
+ H=guidance_img_tensor.shape[2],
329
+ W=guidance_img_tensor.shape[3],
330
+ noise_std=0,
331
+ )
332
+ y = POSTERIOR_MODEL.forward_operator(guidance_img_tensor)
333
+
334
+ print("Running inference...")
335
+ with torch.no_grad():
336
+ # Use the global POSTERIOR_MODEL instance
337
+ result_dict = POSTERIOR_MODEL.forward(y, current_inp_kwargs)
338
+
339
+ x_hat = result_dict["x_hat"]
340
+
341
+ print("Postprocessing result...")
342
+ output_pil = postprocess_image(x_hat)
343
+
344
+ # Convert mask tensor to PIL image for display
345
+ # Mask tensor is [0, 1], take one channel, convert to PIL
346
+ mask_display_tensor = mask_tensor.squeeze(0).cpu().float() # Remove batch, move to CPU
347
+ # If mask_tensor was (B, 3, H, W) and binary 0 or 1 (after repeat)
348
+ # We can take any channel, e.g., mask_display_tensor[0]
349
+ # Ensure it's (H, W) or (1, H, W) for to_pil_image
350
+ if mask_display_tensor.ndim == 3 and mask_display_tensor.shape[0] == 3: # (C, H, W)
351
+ mask_display_tensor = mask_display_tensor[0] # Take one channel (H, W)
352
+
353
+ # Ensure it's in the range [0, 1] and suitable for PIL conversion
354
+ # If it was 0. for masked and 1. for unmasked (or vice-versa depending on logic)
355
+ # TF.to_pil_image expects [0,1] for single channel float
356
+ mask_pil_display = TF.to_pil_image(mask_display_tensor)
357
+
358
+ return output_pil, [output_pil, output_pil], current_config["seed"] # MODIFIED: Removed mask_pil_display
359
+ except gr.Error as e: # Handle Gradio-specific errors first
360
+ raise
361
+ except Exception as e:
362
+ print(f"Error during inpainting: {e}")
363
+ import traceback # Ensure traceback is imported here if not globally
364
+ traceback.print_exc()
365
+ # Return a more user-friendly error message to Gradio
366
+ raise gr.Error(f"An error occurred: {str(e)}. Check console for details.")
367
+
368
+ def super_resolution_image(lr_image, prompt_text, fixed_seed_value, use_random_seed, guidance_scale, num_steps, sr_scale_factor, downscale_input):
369
+ try:
370
+ if lr_image is None:
371
+ raise gr.Error("Please upload a low-resolution image.")
372
+
373
+ current_seed = None
374
+ if use_random_seed:
375
+ current_seed = random.randint(0, 2**32 - 1)
376
+ else:
377
+ try:
378
+ current_seed = int(fixed_seed_value)
379
+ except ValueError:
380
+ raise gr.Error("Seed must be an integer.")
381
+
382
+ # Load Super-Resolution specific configuration
383
+ if not os.path.exists(SR_CONFIG_FILE_PATH):
384
+ raise gr.Error(f"Super-resolution config file not found: {SR_CONFIG_FILE_PATH}")
385
+ with open(SR_CONFIG_FILE_PATH, "r") as f:
386
+ sr_base_config = yaml.safe_load(f)
387
+
388
+ current_sr_config = copy.deepcopy(sr_base_config) # Start with a copy of the base SR config
389
+ current_sr_config["prompt"] = [prompt_text]
390
+ current_sr_config["caption_file"] = None # Ensure prompt mode
391
+ current_sr_config["seed"] = current_seed
392
+
393
+ torch.manual_seed(current_seed)
394
+ np.random.seed(current_seed)
395
+ random.seed(current_seed)
396
+ print(f"Using seed for SR inference: {current_seed}")
397
+
398
+ current_sr_config['n_steps'] = int(num_steps)
399
+ current_sr_config["degradation"]["kwargs"]["scale"] = sr_scale_factor
400
+ current_sr_config["optimizer_dataterm"]["kwargs"]["lr"] = sr_base_config.get("optimizer_dataterm", {}).get("kwargs", {}).get("lr") * sr_scale_factor**2 / (sr_base_config.get("degradation", {}).get("kwargs", {}).get("scale")**2)
401
+ print(f"Using num_steps for SR: {current_sr_config['n_steps']}")
402
+
403
+ # Determine target HR resolution for the output
404
+ hr_resolution = current_sr_config.get("degradation", {}).get("kwargs", {}).get("img_size")
405
+ # Calculate target LR dimensions based on the chosen scale factor
406
+ target_lr_width = int(hr_resolution / sr_scale_factor)
407
+ target_lr_height = int(hr_resolution / sr_scale_factor)
408
+ print(f"Target LR dimensions for SR: {target_lr_width}x{target_lr_height} for scale x{sr_scale_factor}")
409
+
410
+ print("Preparing SR inference inputs (prompt embeddings)...")
411
+ prompt_embeds = embed_prompt(prompt_text, device=PRIMARY_DEVICE)
412
+ current_inp_kwargs = prompt_embeds
413
+ current_inp_kwargs['guidance'] = float(guidance_scale)
414
+ print(f"Using guidance_scale for SR: {current_inp_kwargs['guidance']}")
415
+
416
+ POSTERIOR_MODEL.config = current_sr_config
417
+ POSTERIOR_MODEL.model._guidance_scale = float(guidance_scale)
418
+
419
+ print("Applying SR forward operator...")
420
+
421
+ POSTERIOR_MODEL.forward_operator = degradations.SuperResGradio(
422
+ **current_sr_config["degradation"]["kwargs"]
423
+ )
424
+
425
+ if downscale_input:
426
+ y_tensor = preprocess_lr_image(lr_image, hr_resolution, PRIMARY_DEVICE, DTYPE)
427
+ # y_tensor = POSTERIOR_MODEL.forward_operator(y_tensor)
428
+ y_tensor = torch.nn.functional.interpolate(y_tensor, scale_factor=1/sr_scale_factor, mode='bilinear', align_corners=False, antialias=True)
429
+ # simulate 8bit input by quantizing to 8-bit
430
+ y_tensor = ((y_tensor * 127.5 + 127.5).clamp(0, 255).to(torch.uint8) / 127.5 - 1.0).to(DTYPE)
431
+ else:
432
+ # check if the input image has the correct dimensions
433
+ if lr_image.size[0] != target_lr_width or lr_image.size[1] != target_lr_height:
434
+ raise gr.Error(f"Input image must be {target_lr_width}x{target_lr_height} pixels for the selected scale factor of {sr_scale_factor}.")
435
+ y_tensor = preprocess_lr_image(lr_image, target_lr_width, PRIMARY_DEVICE, DTYPE)
436
+ # add some noise to the input image
437
+ noise_std = current_sr_config.get("degradation", {}).get("kwargs", {}).get("noise_std", 0.0)
438
+ y_tensor += torch.randn_like(y_tensor) * noise_std
439
+ # save for debugging purposes
440
+ # first convert to PIL
441
+ pil_y = postprocess_image(y_tensor)# Remove batch dimension and convert to PIL
442
+ pil_y.save("debug_input_image.png") # Save the input image for debugging
443
+
444
+
445
+ print("Running SR inference...")
446
+ with torch.no_grad():
447
+ result_dict = POSTERIOR_MODEL.forward(y_tensor, current_inp_kwargs)
448
+
449
+ x_hat = result_dict["x_hat"]
450
+
451
+ print("Postprocessing SR result...")
452
+ output_pil = postprocess_image(x_hat)
453
+
454
+ # Upscale input image with nearest neighbor for comparison
455
+ upscaled_input = y_tensor.reshape(1,3,target_lr_height, target_lr_width)
456
+ upscaled_input = POSTERIOR_MODEL.forward_operator.nn(upscaled_input) # Use nearest neighbor upscaling
457
+ upscaled_input = postprocess_image(upscaled_input)
458
+ # save for debugging purposes
459
+ upscaled_input.save("debug_upscaled_input.png") # Save the upscaled input image for debugging
460
+ # upscaled_input = upscaled_input.resize((hr_resolution, hr_resolution), resample=Image.NEAREST)
461
+ return (upscaled_input, output_pil), current_sr_config["seed"]
462
+
463
+ except gr.Error as e:
464
+ raise
465
+ except Exception as e:
466
+ print(f"Error during super-resolution: {e}")
467
+ import traceback
468
+ traceback.print_exc()
469
+ raise gr.Error(f"An error occurred during super-resolution: {str(e)}. Check console for details.")
470
+
471
+
472
+ # Input for seed, allowing users to set it or leave it blank for random/config default
473
+ # Determine default num_steps from BASE_CONFIG if available
474
+ default_num_steps = 50 # Fallback default
475
+ if BASE_CONFIG is not None: # Check if BASE_CONFIG has been initialized
476
+ default_num_steps = BASE_CONFIG.get("num_steps", BASE_CONFIG.get("solver_kwargs", {}).get("num_steps", 50))
477
+
478
+ def superres_preview_preprocess(pil_image, resolution=768):
479
+ if pil_image is None:
480
+ return None
481
+ if pil_image.mode != "RGB":
482
+ pil_image = pil_image.convert("RGB")
483
+ # check if image is smaller than resolution
484
+ original_width, original_height = pil_image.size
485
+ if original_width < resolution or original_height < resolution:
486
+ return pil_image # No resizing needed, return original image
487
+ else:
488
+ pil_image = TF.center_crop(pil_image, [resolution, resolution])
489
+ return pil_image
490
+
491
+
492
+ # Dynamically load examples from demo_images directory
493
+ example_list_inp = []
494
+ example_list_sr = []
495
+ demo_images_dir = os.path.join(project_root, "static/demo_images")
496
+
497
+ if os.path.exists(demo_images_dir):
498
+ filenames = sorted(os.listdir(demo_images_dir))
499
+ processed_bases = set()
500
+ for filename in filenames:
501
+ if filename.startswith("demo_") and filename.endswith("_meta.json"):
502
+ base_name = filename[:-len("_meta.json")] # e.g., "demo_0"
503
+ if base_name in processed_bases:
504
+ continue
505
+
506
+ meta_path = os.path.join(demo_images_dir, filename)
507
+ image_filename = f"{base_name}_image.png"
508
+ image_path = os.path.join(demo_images_dir, image_filename)
509
+ mask_filename = f"{base_name}_mask.png"
510
+ mask_path = os.path.join(demo_images_dir, mask_filename)
511
+
512
+ if os.path.exists(image_path):
513
+ try:
514
+ with open(meta_path, 'r') as f:
515
+ metadata = json.load(f)
516
+ task = metadata.get("task_type")
517
+ prompt = metadata.get("prompt", "")
518
+ if task == "Super Resolution":
519
+ example_list_sr.append([image_path, prompt, task])
520
+ else:
521
+ image_editor_input = {
522
+ "background": image_path,
523
+ "layers": [mask_path],
524
+ "composite": None # Add this key to satisfy ImageEditor's as_example processing
525
+ }
526
+ example_list_inp.append([image_editor_input, prompt, task])
527
+
528
+ # Structure for ImageEditor: { "background": filepath, "layers": [filepath], "composite": None }
529
+
530
+ except json.JSONDecodeError:
531
+ print(f"Warning: Could not decode JSON from {meta_path}. Skipping example {base_name}.")
532
+ except Exception as e:
533
+ print(f"Warning: Error processing example {base_name}: {e}. Skipping.")
534
+ else:
535
+ missing_files = []
536
+ if not os.path.exists(image_path):
537
+ missing_files.append(image_filename)
538
+ if not os.path.exists(mask_path):
539
+ missing_files.append(mask_filename)
540
+ print(f"Warning: Missing files for example {base_name} ({', '.join(missing_files)}). Skipping.")
541
+ else:
542
+ print(f"Info: 'demo_images' directory not found at {demo_images_dir}. No dynamic examples will be loaded.")
543
+
544
+
545
+ if __name__ == "__main__":
546
+ if not os.path.exists(CONFIG_FILE_PATH):
547
+ print(f"ERROR: Configuration file not found at {CONFIG_FILE_PATH}")
548
+ sys.exit(1)
549
+
550
+ initialize_globals()
551
+
552
+ if MODEL is None or POSTERIOR_MODEL is None:
553
+ print("ERROR: Global model initialization failed.")
554
+ sys.exit(1)
555
+
556
+ # --- Define Gradio UI using gr.Blocks after globals are initialized ---
557
+ title_str = "Solving Inverse Problems with FLAIR: Inpainting Demo"
558
+ description_str = """
559
+ Select a task (Inpainting or Super Resolution) and upload an image.
560
+ For Inpainting, draw a mask on the image to specify the area to be filled. We observed that our model can even solve simple editing tasks if provided with an appropriate prompt. For large masks the number of steps might need to be increased, e.g. to 80.
561
+ For Super Resolution, upload a low-resolution image and select the upscaling factor. Images are always upscaled to 768x768 pixels; therefore, for x12 super-resolution, the input image must be 64x64 pixels. You can also upload a high-resolution image, which will be downscaled to the correct input size.
562
+ Use the slider to compare the low-resolution input image with the super-resolved output.
563
+
564
+ """
565
+
566
+ # Determine default values now that BASE_CONFIG is initialized
567
+ default_num_steps = BASE_CONFIG.get("num_steps", BASE_CONFIG.get("solver_kwargs", {}).get("num_steps", 50))
568
+ default_guidance_scale = BASE_CONFIG.get("guidance", 2.0)
569
+
570
+ with gr.Blocks() as iface:
571
+ gr.Markdown(f"## {title_str}")
572
+ gr.Markdown(description_str)
573
+
574
+ task_selector = gr.Dropdown(
575
+ choices=["Inpainting", "Super Resolution"],
576
+ value="Inpainting",
577
+ label="Task"
578
+ )
579
+
580
+ with gr.Row():
581
+ with gr.Column(scale=1): # Input column
582
+ # Inpainting Inputs
583
+ image_editor = gr.ImageEditor(
584
+ type="pil",
585
+ label="Upload Image & Draw Mask (for Inpainting)",
586
+ sources=["upload"],
587
+ height=512,
588
+ width=512,
589
+ visible=True
590
+ )
591
+
592
+ # Super Resolution Inputs
593
+ image_input = gr.Image(
594
+ type="pil",
595
+ label="Upload Low-Resolution Image (for Super Resolution)",
596
+ visible=False
597
+ )
598
+
599
+ sr_scale_slider = gr.Dropdown(
600
+ choices=[2, 4, 8, 12, 24],
601
+ value=12,
602
+ label="Upscaling Factor (Super Resolution)",
603
+ interactive=True,
604
+ visible=False # Initially hidden
605
+ )
606
+ downscale_input = gr.Checkbox(
607
+ label="Downscale the provided image.",
608
+ value=True,
609
+ interactive=True,
610
+ visible=False # Initially hidden
611
+ )
612
+
613
+ # Common Inputs
614
+ prompt_text = gr.Textbox(
615
+ label="Prompt",
616
+ placeholder="E.g., a beautiful landscape, a detailed portrait"
617
+ )
618
+ seed_slider = gr.Slider(
619
+ minimum=0,
620
+ maximum=2**32 -1, # Max for torch.manual_seed
621
+ step=1,
622
+ label="Seed (if not random)",
623
+ value=42,
624
+ interactive=True
625
+ )
626
+
627
+ use_random_seed_checkbox = gr.Checkbox(
628
+ label="Use Random Seed",
629
+ value=True,
630
+ interactive=True
631
+ )
632
+ guidance_scale_slider = gr.Slider(
633
+ minimum=1.0,
634
+ maximum=15.0,
635
+ step=0.5,
636
+ value=default_guidance_scale,
637
+ label="Guidance Scale"
638
+ )
639
+ num_steps_slider = gr.Slider(
640
+ minimum=28,
641
+ maximum=150,
642
+ step=1,
643
+ value=default_num_steps,
644
+ label="Number of Steps"
645
+ )
646
+ submit_button = gr.Button("Submit")
647
+
648
+ # # Add Save Configuration button and status text
649
+ # gr.Markdown("---") # Separator
650
+ # save_button = gr.Button("Save Current Configuration for Demo")
651
+ # save_status_text = gr.Markdown()
652
+
653
+ with gr.Column(scale=1): # Output column
654
+ output_image_display = gr.Image(type="pil", label="Result")
655
+ sr_compare_display = ImageSlider(label="Super-Resolution: Input vs Output", visible=False, position=0.5)
656
+
657
+
658
+
659
+
660
+ # --- Task routing and visibility logic ---
661
+ def update_visibility(task):
662
+ is_inpainting = task == "Inpainting"
663
+ is_super_resolution = task == "Super Resolution"
664
+ return {
665
+ image_editor: gr.update(visible=is_inpainting),
666
+ image_input: gr.update(visible=is_super_resolution),
667
+ sr_scale_slider: gr.update(visible=is_super_resolution),
668
+ downscale_input: gr.update(visible=is_super_resolution),
669
+ output_image_display: gr.update(visible=is_inpainting),
670
+ sr_compare_display: gr.update(visible=is_super_resolution, position=0.5),
671
+ downscale_input: gr.update(visible=is_super_resolution),
672
+ }
673
+
674
+ task_selector.change(
675
+ fn=update_visibility,
676
+ inputs=[task_selector],
677
+ outputs=[image_editor, image_input, sr_scale_slider, downscale_input, output_image_display, sr_compare_display]
678
+ )
679
+
680
+
681
+ # MODIFIED route_task to accept sr_scale_factor
682
+ def route_task(task, image_editor_data, lr_image_for_sr, prompt_text, fixed_seed_value, use_random_seed, guidance_scale, num_steps, sr_scale_factor_value, downscale_input):
683
+ if task == "Inpainting":
684
+ return inpaint_image(image_editor_data, prompt_text, fixed_seed_value, use_random_seed, guidance_scale, num_steps)
685
+ elif task == "Super Resolution":
686
+ result_images, seed_val = super_resolution_image(
687
+ lr_image_for_sr, prompt_text, fixed_seed_value, use_random_seed,
688
+ guidance_scale, num_steps, sr_scale_factor_value, downscale_input
689
+ )
690
+ return result_images[1], gr.update(value=result_images, position=0.5), seed_val
691
+ else:
692
+ raise gr.Error("Unsupported task.")
693
+
694
+ submit_button.click(
695
+ fn=route_task,
696
+ inputs=[
697
+ task_selector,
698
+ image_editor,
699
+ image_input,
700
+ prompt_text,
701
+ seed_slider,
702
+ use_random_seed_checkbox,
703
+ guidance_scale_slider,
704
+ num_steps_slider,
705
+ sr_scale_slider,
706
+ downscale_input,
707
+ ],
708
+ outputs=[
709
+ output_image_display,
710
+ sr_compare_display,
711
+ seed_slider
712
+ ]
713
+ )
714
+
715
+ # Wire up the save button
716
+ # save_button.click(
717
+ # fn=save_configuration,
718
+ # inputs=[
719
+ # image_editor,
720
+ # image_input,
721
+ # prompt_text,
722
+ # seed_slider,
723
+ # task_selector,
724
+ # use_random_seed_checkbox,
725
+ # num_steps_slider,
726
+ # ],
727
+ # outputs=[save_status_text]
728
+ # )
729
+
730
+
731
+ gr.Markdown("---") # Separator
732
+ gr.Markdown("### Click an example to load:")
733
+ def load_example(example_data, prompt, task):
734
+ image_editor_input = example_data[0]
735
+ prompt_value = example_data[1]
736
+ if task == "Inpainting":
737
+ image_editor.clear() # Clear current image and mask
738
+ if image_editor_input and image_editor_input.get("background"):
739
+ image_editor.upload_image(image_editor_input["background"])
740
+ if image_editor_input and image_editor_input.get("layers"):
741
+ for layer in image_editor_input["layers"]:
742
+ image_editor.upload_mask(layer)
743
+ elif task == "Super Resolution":
744
+ image_input.clear()
745
+ image_input.upload_image(image_editor_input)
746
+
747
+ # Set the prompt
748
+ prompt_text.value = prompt_value
749
+ # Optionally, set a random seed and guidance scale
750
+ seed_slider.value = random.randint(0, 2**32 - 1)
751
+ guidance_scale_slider.value = default_guidance_scale
752
+ # Set the task selector from the example
753
+ task_selector.set_value(task)
754
+ update_visibility(task) # Update visibility based on task
755
+
756
+ with gr.Row():
757
+ gr.Examples(
758
+ examples=example_list_sr,
759
+ inputs=[image_input, prompt_text, task_selector],
760
+ label="Super Resolution Examples",
761
+ fn=load_example,
762
+ )
763
+ with gr.Row():
764
+ gr.Examples(
765
+ examples=example_list_inp,
766
+ inputs=[image_editor, prompt_text, task_selector],
767
+ label="Inpainting Examples",
768
+ fn=load_example,
769
+ )
770
+
771
+ # --- End of Gradio UI definition ---
772
+
773
+ print("Launching Gradio demo...")
774
+ iface.launch()
assets/teaser3.svg ADDED
configs/inpainting.yaml ADDED
@@ -0,0 +1,45 @@
1
+ # === Model & Data Settings ===
2
+ model: "SD3"
3
+ resolution: 768
4
+ lambda_func: v
5
+ optimized_reg_weight: ./LCFM/SD3_loss_v_MSE_DIV2k_neg_prompt.npy
6
+ regularizer_weight: 0.5
7
+ likelihood_weight_mode: reg_weight
8
+ likelihood_steps: 15
9
+ early_stopping: 1.e-4
10
+ epochs: 1
11
+ guidance: 2
12
+ quantize: False
13
+ use_tiny_ae: True
14
+ negative_prompt: ""
15
+ reg-shift: 0.0
16
+
17
+ # === Optimization Settings ===
18
+ optimizer:
19
+ name: SGD
20
+ kwargs:
21
+ lr: 1
22
+ optimizer_dataterm:
23
+ name: SGD
24
+ kwargs:
25
+ lr: 0.1
26
+
27
+ # === Sampling & Misc ===
28
+ t_sampling: descending
29
+ n_steps: 50
30
+ inv_alpha: 1-t
31
+ ts_min: 0.18
32
+ projection: False
33
+ seed: 42
34
+
35
+ # === Experiment-Specific Settings ===
36
+ prompt: A high quality photo of a girl with a pirate eye-patch.
37
+ degradation:
38
+ name: Inpainting
39
+ kwargs:
40
+ mask: [128, 640, 384, 640]
41
+ H: 768
42
+ W: 768
43
+ noise_std: 0.01
44
+ likelihood_weight: 1
45
+
configs/inpainting_gradio.yaml ADDED
@@ -0,0 +1,45 @@
1
+ # === Model & Data Settings ===
2
+ model: "SD3"
3
+ resolution: 512
4
+ lambda_func: v
5
+ optimized_reg_weight: ./LCFM/SD3_loss_v_MSE_DIV2k_neg_prompt.npy
6
+ regularizer_weight: 0.5
7
+ likelihood_weight_mode: reg_weight
8
+ likelihood_steps: 15
9
+ early_stopping: 1.e-4
10
+ epochs: 1
11
+ guidance: 2
12
+ quantize: False
13
+ use_tiny_ae: True
14
+ negative_prompt: ""
15
+ reg-shift: 0.0
16
+
17
+ # === Optimization Settings ===
18
+ optimizer:
19
+ name: SGD
20
+ kwargs:
21
+ lr: 1
22
+ optimizer_dataterm:
23
+ name: SGD
24
+ kwargs:
25
+ lr: 0.1
26
+
27
+ # === Sampling & Misc ===
28
+ t_sampling: descending
29
+ n_steps: 50
30
+ inv_alpha: 1-t
31
+ ts_min: 0.18
32
+ projection: False
33
+ seed: 42
34
+
35
+ # === Experiment-Specific Settings ===
36
+ prompt: A high quality photo of a girl with a pirate eye-patch.
37
+ degradation:
38
+ name: Inpainting
39
+ kwargs:
40
+ mask: [128, 640, 384, 640]
41
+ H: 512
42
+ W: 512
43
+ noise_std: 0.0
44
+ likelihood_weight: 1
45
+
configs/motion_deblur.yaml ADDED
@@ -0,0 +1,43 @@
1
+ # === Model & Data Settings ===
2
+ model: "SD3"
3
+ resolution: 768
4
+ lambda_func: v
5
+ optimized_reg_weight: SD3_loss_v_norm_div2k_squared.npy
6
+ regularizer_weight: 0.5
7
+ likelihood_weight_mode: reg_weight
8
+ likelihood_steps: 15
9
+ early_stopping: 1.e-4
10
+ epochs: 1
11
+ guidance: 2
12
+ quantize: False
13
+ use_tiny_ae: True
14
+ negative_prompt: ""
15
+ reg-shift: 0.0
16
+
17
+ # === Optimization Settings ===
18
+ optimizer:
19
+ name: SGD
20
+ kwargs:
21
+ lr: 1
22
+ optimizer_dataterm:
23
+ name: SGD
24
+ kwargs:
25
+ lr: 0.1
26
+
27
+ # === Sampling & Misc ===
28
+ t_sampling: descending
29
+ n_steps: 50
30
+ inv_alpha: 1-t
31
+ ts_min: 0.18
32
+ projection: False
33
+ seed: 42
34
+
35
+ # === Experiment-Specific Settings ===
36
+ prompt: a high quality photo of a face
37
+ degradation:
38
+ name: MotionBlur
39
+ kwargs:
40
+ kernel_size: 61
41
+ noise_std: 0.01
42
+ img_size: 768
43
+ likelihood_weight: 1
configs/x12.yaml ADDED
@@ -0,0 +1,48 @@
1
+ # OmegaConf base config for neurips experiments
2
+ # Place common settings here. Each experiment config will override as needed.
3
+
4
+ # === Model & Data Settings ===
5
+ model: "SD3"
6
+ resolution: 768
7
+ lambda_func: v
8
+ optimized_reg_weight: ./LCFM/SD3_loss_v_MSE_DIV2k_neg_prompt.npy
9
+ regularizer_weight: 0.5
10
+ likelihood_weight_mode: reg_weight
11
+ likelihood_steps: 15
12
+ early_stopping: 1.e-4
13
+ epochs: 1
14
+ guidance: 2
15
+ quantize: False
16
+ use_tiny_ae: True
17
+ negative_prompt: ""
18
+ reg-shift: 0.0
19
+
20
+ # === Optimization Settings ===
21
+ optimizer:
22
+ name: SGD
23
+ kwargs:
24
+ lr: 1
25
+ optimizer_dataterm:
26
+ name: SGD
27
+ kwargs:
28
+ lr: 12
29
+
30
+ # === Sampling & Misc ===
31
+ t_sampling: descending
32
+ n_steps: 50
33
+ inv_alpha: 1-t
34
+ quantize: False
35
+ ts_min: 0.18
36
+ projection: False
37
+ seed: 3
38
+
39
+ # === Experiment-Specific Settings ===
40
+ prompt: a high quality photo of
41
+ # Super-resolution x12 for FFHQ
42
+ degradation:
43
+ name: SuperRes
44
+ kwargs:
45
+ scale: 12
46
+ noise_std: 0.01
47
+ img_size: 768
48
+ likelihood_weight: 1
configs/x12_gradio.yaml ADDED
@@ -0,0 +1,48 @@
1
+ # OmegaConf base config for neurips experiments
2
+ # Place common settings here. Each experiment config will override as needed.
3
+
4
+ # === Model & Data Settings ===
5
+ model: "SD3"
6
+ resolution: 768
7
+ lambda_func: v
8
+ optimized_reg_weight: ./LCFM/SD3_loss_v_MSE_DIV2k_neg_prompt.npy
9
+ regularizer_weight: 0.5
10
+ likelihood_weight_mode: reg_weight
11
+ likelihood_steps: 15
12
+ early_stopping: 1.e-4
13
+ epochs: 1
14
+ guidance: 2
15
+ quantize: False
16
+ use_tiny_ae: True
17
+ negative_prompt: ""
18
+ reg-shift: 0.0
19
+
20
+ # === Optimization Settings ===
21
+ optimizer:
22
+ name: SGD
23
+ kwargs:
24
+ lr: 1
25
+ optimizer_dataterm:
26
+ name: SGD
27
+ kwargs:
28
+ lr: 12
29
+
30
+ # === Sampling & Misc ===
31
+ t_sampling: descending
32
+ n_steps: 50
33
+ inv_alpha: 1-t
34
+ quantize: False
35
+ ts_min: 0.18
36
+ projection: False
37
+ seed: 3
38
+
39
+ # === Experiment-Specific Settings ===
40
+ prompt: a high quality photo of
41
+ # Super-resolution x12 for FFHQ
42
+ degradation:
43
+ name: SuperRes
44
+ kwargs:
45
+ scale: 12
46
+ noise_std: 0.0
47
+ img_size: 768
48
+ likelihood_weight: 1
demo_images/demo_0_image.png ADDED

Git LFS Details

  • SHA256: d2ba059669ab9f2a7aa855a0a8980cfd9f0871ddfea9d8e995a912fd8a35b56a
  • Pointer size: 131 Bytes
  • Size of remote file: 578 kB
demo_images/demo_0_meta.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "prompt": "a high quality image of a face.",
3
+ "seed_on_slider": 2646030500,
4
+ "use_random_seed_checkbox": true,
5
+ "num_steps": 50,
6
+ "task_type": "Super Resolution"
7
+ }
demo_images/demo_1_image.png ADDED

Git LFS Details

  • SHA256: cab6e98eb2fff0bd4750e0f390a941a32cc1fb52e3414aaa61510120c7a6c28b
  • Pointer size: 132 Bytes
  • Size of remote file: 5.95 MB
demo_images/demo_1_mask.png ADDED

Git LFS Details

  • SHA256: f6be461659314edf39c1fc613e10679b31e76a2bfc951df5c3de61f14cbea670
  • Pointer size: 130 Bytes
  • Size of remote file: 13.7 kB
demo_images/demo_1_meta.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "prompt": "a house",
3
+ "seed_on_slider": 434714511,
4
+ "use_random_seed_checkbox": false,
5
+ "num_steps": 50,
6
+ "task_type": "Inpainting"
7
+ }
demo_images/demo_2_image.png ADDED

Git LFS Details

  • SHA256: a728ea3db369044155c037227c48ba03287926521df02ccaa3d1410d1cd161f8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.69 MB
demo_images/demo_2_mask.png ADDED

Git LFS Details

  • SHA256: 75ddf85acc020c6ad069309f3a2b4cdcefa24800abfddcc910b25b7337949ace
  • Pointer size: 129 Bytes
  • Size of remote file: 6.73 kB
demo_images/demo_2_meta.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "prompt": "a high quality image of a face.",
3
+ "seed_on_slider": 3211750901,
4
+ "use_random_seed_checkbox": false,
5
+ "num_steps": 50,
6
+ "task_type": "Inpainting"
7
+ }
demo_images/demo_3_image.png ADDED

Git LFS Details

  • SHA256: f9b5c8a1043259b8d86e0502ad6139158145451a13c56723dd4671023402d0b5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.64 MB
demo_images/demo_3_mask.png ADDED

Git LFS Details

  • SHA256: 5fccd5046c9861df385cdbfa80e266d2eaca74e5ccca4295e253ceb20c0af97b
  • Pointer size: 130 Bytes
  • Size of remote file: 16.4 kB
demo_images/demo_3_meta.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "prompt": "a girl with bright blue eyes ",
3
+ "seed_on_slider": 4257757956,
4
+ "use_random_seed_checkbox": true,
5
+ "num_steps": 50,
6
+ "task_type": "Inpainting"
7
+ }
examples/girl.png ADDED

Git LFS Details

  • SHA256: cbede299ec1bb921d9ea60b6b6b5dcb3f8bb6025f7dc163c7af744089e011f71
  • Pointer size: 132 Bytes
  • Size of remote file: 1.46 MB
examples/sunflowers.png ADDED

Git LFS Details

  • SHA256: 32d749daec6c56584e6c9c32432e49b315153f7e74375caab5363bbeb33d9e18
  • Pointer size: 132 Bytes
  • Size of remote file: 4.26 MB
inference_scripts/run_image_inv.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright (c) 2024 <Julius Erbach ETH Zurich>
2
+ #
3
+ # This file is part of the var_post_samp project and is licensed under the MIT License.
4
+ # See the LICENSE file in the project root for more information.
5
+ """
6
+ Usage:
7
+ python run_image_inv.py --config <config.yaml>
8
+ """
9
+
10
+ import os
11
+ import sys
12
+ import time
13
+ import csv
14
+ import yaml
15
+ import torch
16
+ import random
17
+ import click
18
+ import numpy as np
19
+ import tqdm
20
+ import datetime
21
+ import torchvision
22
+
23
+ from flair.helper_functions import parse_click_context
24
+ from flair.pipelines import model_loader
25
+ from flair.utils import data_utils
26
+ from flair import var_post_samp
27
+
28
+
29
+ dtype = torch.bfloat16
30
+
31
+ num_gpus = torch.cuda.device_count()
32
+ if num_gpus > 0:
33
+ devices = [f"cuda:{i}" for i in range(num_gpus)]
34
+ primary_device = devices[0]
35
+ print(f"Using devices: {devices}")
36
+ print(f"Primary device for operations: {primary_device}")
37
+ else:
38
+ print("No CUDA devices found. Using CPU.")
39
+ devices = ["cpu"]
40
+ primary_device = "cpu"
41
+
42
+
43
+ @click.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True})
44
+ @click.option("--config", "config_file_arg", type=click.Path(exists=True), help="Path to the config file")
45
+ @click.option("--target_file", type=click.Path(exists=True), help="Path to the target file or folder")
46
+ @click.option("--result_folder", type=click.Path(file_okay=False, dir_okay=True, writable=True, resolve_path=True), help="Path to the output folder. It will be created if it doesn't exist.")
47
+ @click.option("--mask_file", type=click.Path(exists=True), default=None, help="Path to the mask .npy file. Optional; used for image inpainting. True pixels are observed.")
48
+ @click.pass_context
49
+ def main(ctx, config_file_arg, target_file, result_folder, mask_file=None):
50
+ """Main entry point for image inversion and sampling.
51
+
52
+ The user must provide either a caption_file (with per-image captions) OR a single prompt for all images in the config YAML file.
53
+ """
54
+ with open(config_file_arg, "r") as f:
55
+ config = yaml.safe_load(f)
56
+ ctx = parse_click_context(ctx)
57
+ config.update(ctx)
58
+ # Read caption_file and prompt from config
59
+ caption_file = config.get("caption_file", None)
60
+ prompt = config.get("prompt", None)
61
+
62
+ # Enforce mutually exclusive caption_file or prompt
63
+ if (not caption_file and not prompt) or (caption_file and prompt):
64
+ raise ValueError("You must provide either 'caption_file' OR 'prompt' (not both) in the config file. See documentation.")
65
+
66
+ # wandb removed, so config_dict is just a copy
67
+ config_dict = dict(config)
68
+ torch.manual_seed(config["seed"])
69
+ np.random.seed(config["seed"])
70
+ random.seed(config["seed"])
71
+
72
+ # Use config values as-is (no to_absolute_path)
73
+ caption_file = caption_file if caption_file else None
74
+
75
+ guidance_img_iterator = data_utils.yield_images(
76
+ target_file, size=config["resolution"]
77
+ )
78
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
79
+ counter = 1
80
+ name = f'results_{config["model"]}_{config["degradation"]["name"]}_resolution_{config["resolution"]}_noise_{config["degradation"]["kwargs"]["noise_std"]}_{timestamp}'
81
+ candidate = os.path.join(name)
82
+ while os.path.exists(candidate):
83
+ candidate = os.path.join(f"{name}_{counter}")
84
+ counter += 1
85
+ output_folders = data_utils.generate_output_structure(
86
+ result_folder,
87
+ [
88
+ candidate,
89
+ f'input_{config["degradation"]["name"]}_resolution_{config["resolution"]}_noise_{config["degradation"]["kwargs"]["noise_std"]}',
90
+ f'target_{config["degradation"]["name"]}_resolution_{config["resolution"]}_noise_{config["degradation"]["kwargs"]["noise_std"]}',
91
+ ],
92
+ )
93
+ config_out = os.path.join(os.path.split(output_folders[0])[0], "config.yaml")
94
+ with open(config_out, "w") as f:
95
+ yaml.safe_dump(config_dict, f)
96
+
97
+ source_files = list(data_utils.find_files(target_file, ext="png"))
98
+ num_images = len(source_files)
99
+ print(f"Found {num_images} images.")
100
+
101
+ # Load captions
102
+ if caption_file:
103
+ captions = data_utils.load_captions_from_file(caption_file, user_prompt="")
104
+ if not captions:
105
+ sys.exit("Error: No captions were loaded from the provided caption file.")
106
+ if len(captions) != num_images:
107
+ print("Warning: Number of captions does not match number of images.")
108
+ prompts_in_order = [captions.get(os.path.basename(f), "") for f in source_files]
109
+ else:
110
+ # Use the single prompt for all images
111
+ prompts_in_order = [prompt for _ in range(num_images)]
112
+
113
+ if any(p == "" for p in prompts_in_order):
114
+ print("Warning: Some images might not have corresponding captions or prompt is empty.")
115
+ config["prompt"] = prompts_in_order
116
+
117
+ model, inp_kwargs = model_loader.load_model(config, device=devices)
118
+ if mask_file and config["degradation"]["name"] == "Inpainting":
119
+ config["degradation"]["kwargs"]["mask"] = mask_file
120
+ posterior_model = var_post_samp.VariationalPosterior(model, config)
121
+ guidance_img_iterator = data_utils.yield_images(
122
+ target_file, size=config["resolution"]
123
+ )
124
+ for idx, guidance_img in tqdm.tqdm(enumerate(guidance_img_iterator), total=num_images):
125
+ guidance_img = guidance_img.to(dtype).cuda()
126
+ y = posterior_model.forward_operator(guidance_img)
127
+ tic = time.time()
128
+ with torch.no_grad():
129
+ result_dict = posterior_model.forward(y, inp_kwargs[idx])
130
+ x_hat = result_dict["x_hat"]
131
+ toc = time.time()
132
+ print(f"Runtime: {toc - tic}")
133
+ guidance_img = guidance_img.cuda()
134
+ result_file = output_folders[0].format(idx)
135
+ input_file = output_folders[1].format(idx)
136
+ ground_truth_file = output_folders[2].format(idx)
137
+ x_hat_pil = torchvision.transforms.ToPILImage()(
138
+ x_hat.float()[0].clip(-1, 1) * 0.5 + 0.5
139
+ )
140
+ x_hat_pil.save(result_file)
141
+ try:
142
+ if config["degradation"]["name"] == "SuperRes":
143
+ input_img = posterior_model.forward_operator.nn(y)
144
+ else:
145
+ input_img = posterior_model.forward_operator.pseudo_inv(y)
146
+ input_img_pil = torchvision.transforms.ToPILImage()(
147
+ input_img.float()[0].clip(-1, 1) * 0.5 + 0.5
148
+ )
149
+ input_img_pil.save(input_file)
150
+ except Exception:
151
+ print("Error in pseudo-inverse operation. Skipping input image save.")
152
+ guidance_img_pil = torchvision.transforms.ToPILImage()(
153
+ guidance_img.float()[0] * 0.5 + 0.5
154
+ )
155
+ guidance_img_pil.save(ground_truth_file)
156
+
157
+
158
+ if __name__ == "__main__":
159
+ main()
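A minimal config sketch for the script above (hedged: the "model" value and the exact degradation kwargs accepted by model_loader are assumptions; the key names mirror what run_image_inv.py reads — seed, resolution, prompt or caption_file, and the degradation block — see configs/*.yaml for the real values):

    # Hypothetical config written as Python, then dumped to the YAML the script expects.
    import yaml

    config = {
        "model": "SD3",                       # placeholder; must match what model_loader expects
        "seed": 42,
        "resolution": 768,
        "prompt": "a photo of a sunflower",   # or "caption_file": "captions.txt" (not both)
        "degradation": {
            "name": "SuperRes",
            "kwargs": {"scale": 12, "noise_std": 0.01, "img_size": 768},
        },
    }
    with open("config.yaml", "w") as f:
        yaml.safe_dump(config, f)
    # python inference_scripts/run_image_inv.py --config config.yaml \
    #     --target_file <image folder> --result_folder output/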
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ tqdm
2
+ torch
3
+ click
4
+ pyyaml
5
+ numpy
6
+ torchvision
7
+ diffusers
8
+ Pillow
9
+ scipy
10
+ munch
11
+ transformers
12
+ torchmetrics
13
+ scikit-image
14
+ opencv-python
15
+ sentencepiece
16
+ protobuf
17
+ accelerate
18
+ gradio
scripts/compute_metrics.py ADDED
@@ -0,0 +1,144 @@
1
+ import torchmetrics
2
+ import torch
3
+ from PIL import Image
4
+ import argparse
5
+ from flair.utils import data_utils
6
+ import os
7
+ import tqdm
8
+ import torch.nn.functional as F
9
+ from torchmetrics.image.kid import KernelInceptionDistance
10
+
11
+
12
+ MAX_BATCH_SIZE = None
13
+
14
+ @torch.no_grad()
15
+ def main(args):
16
+ # Determine device
17
+ if args.device == "cuda" and torch.cuda.is_available():
18
+ device = torch.device("cuda")
19
+ else:
20
+ device = torch.device("cpu")
21
+ print(f"Using device: {device}")
22
+
23
+ # load images
24
+ gt_iterator = data_utils.yield_images(os.path.abspath(args.gt), size=args.resolution)
25
+ pred_iterator = data_utils.yield_images(os.path.abspath(args.pred), size=args.resolution)
26
+ fid_metric = torchmetrics.image.fid.FrechetInceptionDistance(normalize=True).to(device)
27
+ # kid_metric = KernelInceptionDistance(subset_size=args.kid_subset_size, normalize=True).to(device)
28
+ lpips_metric = torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(
29
+ net_type="alex", normalize=False, reduction="mean"
30
+ ).to(device)
31
+ if args.patch_size:
32
+ patch_fid_metric = torchmetrics.image.fid.FrechetInceptionDistance(normalize=True).to(device)
33
+ # patch_kid_metric = KernelInceptionDistance(subset_size=args.kid_subset_size, normalize=True).to(device)
34
+ psnr_list = []
35
+ lpips_list = []
36
+ ssim_list = []
37
+ # iterate over images
38
+ for gt, pred in tqdm.tqdm(zip(gt_iterator, pred_iterator)):
39
+ # Move tensors to the selected device
40
+ gt = gt.to(device)
41
+ pred = pred.to(device)
42
+
43
+ # resize gt to pred size
44
+ if gt.shape[-2:] != (args.resolution, args.resolution):
45
+ gt = F.interpolate(gt, size=args.resolution, mode="area")
46
+ if pred.shape[-2:] != (args.resolution, args.resolution):
47
+ pred = F.interpolate(pred, size=args.resolution, mode="area")
48
+ # to range [0,1]
49
+ gt_norm = gt * 0.5 + 0.5
50
+ pred_norm = pred * 0.5 + 0.5
51
+ # compute PSNR
52
+ psnr = torchmetrics.functional.image.peak_signal_noise_ratio(
53
+ pred_norm, gt_norm, data_range=1.0
54
+ )
55
+ psnr_list.append(psnr.cpu()) # Move result to CPU
56
+ # compute LPIPS
57
+ lpips_score = lpips_metric(pred.clip(-1,1), gt.clip(-1,1))
58
+ lpips_list.append(lpips_score.cpu()) # Move result to CPU
59
+ # compute SSIM
60
+ ssim = torchmetrics.functional.image.structural_similarity_index_measure(
61
+ pred_norm, gt_norm, data_range=1.0
62
+ )
63
+ ssim_list.append(ssim.cpu()) # Move result to CPU
64
+ print(f"PSNR: {psnr}, LPIPS: {lpips_score}, SSIM: {ssim}")
65
+ # compute FID
66
+ # Ensure inputs are on the correct device (already handled by moving gt/pred earlier)
67
+ fid_metric.update(gt_norm, real=True)
68
+ fid_metric.update(pred_norm, real=False)
69
+ # compute KID
70
+ # kid_metric.update(pred, real=False)
71
+ # kid_metric.update(gt, real=True)
72
+ # compute Patchwise FID/KID if patch_size is specified
73
+ if args.patch_size:
74
+ # Extract patches
75
+ patch_size = args.patch_size
76
+ gt_patches = F.unfold(gt_norm, kernel_size=patch_size, stride=patch_size)
77
+ pred_patches = F.unfold(pred_norm, kernel_size=patch_size, stride=patch_size)
78
+ # Reshape patches: (B, C*P*P, N_patches) -> (B*N_patches, C, P, P)
79
+ B, C, H, W = gt.shape
80
+ N_patches = gt_patches.shape[-1]
81
+ gt_patches = gt_patches.permute(0, 2, 1).reshape(B * N_patches, C, patch_size, patch_size)
82
+ pred_patches = pred_patches.permute(0, 2, 1).reshape(B * N_patches, C, patch_size, patch_size)
83
+ # Update patch FID metric (inputs are already on the correct device)
84
+ # Update patch KID metric
85
+ # process mini batches of patches
86
+ if MAX_BATCH_SIZE is None:
87
+ patch_fid_metric.update(pred_patches, real=False)
88
+ patch_fid_metric.update(gt_patches, real=True)
89
+ # patch_kid_metric.update(pred_patches, real=False)
90
+ # patch_kid_metric.update(gt_patches, real=True)
91
+ else:
92
+ for i in range(0, N_patches, MAX_BATCH_SIZE):
93
+ patch_fid_metric.update(pred_patches[i:i + MAX_BATCH_SIZE], real=False)
94
+ patch_fid_metric.update(gt_patches[i:i + MAX_BATCH_SIZE], real=True)
95
+ # patch_kid_metric.update(pred_patches[i:i + MAX_BATCH_SIZE], real=False)
96
+ # patch_kid_metric.update(gt_patches[i:i + MAX_BATCH_SIZE], real=True)
97
+
98
+ # compute FID
99
+ fid = fid_metric.compute()
100
+ # compute KID
101
+ # kid_mean, kid_std = kid_metric.compute()
102
+ if args.patch_size:
103
+ patch_fid = patch_fid_metric.compute()
104
+ # patch_kid_mean, patch_kid_std = patch_kid_metric.compute()
105
+ # compute average metrics (on CPU)
106
+ avg_psnr = torch.mean(torch.stack(psnr_list))
107
+ avg_lpips = torch.mean(torch.stack(lpips_list))
108
+ avg_ssim = torch.mean(torch.stack(ssim_list))
109
+ # compute standard deviation (on CPU)
110
+ std_psnr = torch.std(torch.stack(psnr_list))
111
+ std_lpips = torch.std(torch.stack(lpips_list))
112
+ std_ssim = torch.std(torch.stack(ssim_list))
113
+ print(f"PSNR: {avg_psnr} +/- {std_psnr}")
114
+ print(f"LPIPS: {avg_lpips} +/- {std_lpips}")
115
+ print(f"SSIM: {avg_ssim} +/- {std_ssim}")
116
+ print(f"FID: {fid}") # FID is computed on the selected device, print directly
117
+ # print(f"KID: {kid_mean} +/- {kid_std}") # KID is computed on the selected device, print directly
118
+ if args.patch_size:
119
+ print(f"Patch FID ({args.patch_size}x{args.patch_size}): {patch_fid}") # Patch FID is computed on the selected device, print directly
120
+ # print(f"Patch KID ({args.patch_size}x{args.patch_size}): {patch_kid_mean} +/- {patch_kid_std}") # Patch KID is computed on the selected device, print directly
121
+ # save to prediction folder
122
+ out_file = os.path.join(args.pred, "fid_metrics.txt")
123
+ with open(out_file, "w") as f:
124
+ f.write(f"PSNR: {avg_psnr.item()} +/- {std_psnr.item()}\n") # Use .item() for scalar tensors
125
+ f.write(f"LPIPS: {avg_lpips.item()} +/- {std_lpips.item()}\n")
126
+ f.write(f"SSIM: {avg_ssim.item()} +/- {std_ssim.item()}\n")
127
+ f.write(f"FID: {fid.item()}\n") # Use .item() for scalar tensors
128
+ # f.write(f"KID: {kid_mean.item()} +/- {kid_std.item()}\n") # Use .item() for scalar tensors
129
+ if args.patch_size:
130
+ f.write(f"Patch FID ({args.patch_size}x{args.patch_size}): {patch_fid.item()}\n") # Use .item() for scalar tensors
131
+ # f.write(f"Patch KID ({args.patch_size}x{args.patch_size}): {patch_kid_mean.item()} +/- {patch_kid_std.item()}\n") # Use .item() for scalar tensors
132
+
133
+
134
+ if __name__ == "__main__":
135
+ parser = argparse.ArgumentParser(description="Compute metrics")
136
+ parser.add_argument("--gt", type=str, help="Path to ground truth image")
137
+ parser.add_argument("--pred", type=str, help="Path to predicted image")
138
+ parser.add_argument("--resolution", type=int, default=768, help="resolution at which to evaluate")
139
+ parser.add_argument("--patch_size", type=int, default=None, help="Patch size for Patchwise FID/KID computation (e.g., 12). If None, skip.")
140
+ parser.add_argument("--kid_subset_size", type=int, default=1000, help="Subset size for KID computation.")
141
+ parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "cuda"], help="Device to run computation on (cpu or cuda)")
142
+ args = parser.parse_args()
143
+
144
+ main(args)
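A small sketch of the patch extraction used for the patchwise FID above (sizes are hypothetical): F.unfold cuts the image into non-overlapping p x p patches, which are then fed to the FID metric as individual samples.

    import torch
    import torch.nn.functional as F

    x = torch.rand(1, 3, 768, 768)                  # image in [0, 1]
    p = 12
    patches = F.unfold(x, kernel_size=p, stride=p)  # (1, 3*p*p, N), N = (768 // p) ** 2
    N = patches.shape[-1]
    patches = patches.permute(0, 2, 1).reshape(1 * N, 3, p, p)
    print(patches.shape)                            # torch.Size([4096, 3, 12, 12])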
scripts/generate_caption.py ADDED
@@ -0,0 +1,107 @@
1
+ import sys
2
+ import os
3
+ import argparse
4
+ from PIL import Image
5
+
6
+ # Add the path to the thirdparty/SeeSR directory to the Python path
7
+ sys.path.append(os.path.abspath("./thirdparty/SeeSR"))
8
+
9
+ import torch
10
+ from torchvision import transforms
11
+ from ram.models.ram_lora import ram
12
+ from ram import inference_ram as inference
13
+
14
+ def load_ram_model(ram_model_path: str, dape_model_path: str):
15
+ """
16
+ Load the RAM model with the given paths.
17
+
18
+ Args:
19
+ ram_model_path (str): Path to the pretrained RAM model.
20
+ dape_model_path (str): Path to the pretrained DAPE model.
21
+
22
+ Returns:
23
+ torch.nn.Module: Loaded RAM model.
24
+ """
25
+ device = "cuda" if torch.cuda.is_available() else "cpu"
26
+
27
+ # Load the RAM model
28
+ tag_model = ram(pretrained=ram_model_path, pretrained_condition=dape_model_path, image_size=384, vit="swin_l")
29
+ tag_model.eval()
30
+ return tag_model.to(device)
31
+
32
+ def generate_caption(image_path: str, tag_model) -> str:
33
+ """
34
+ Generate a caption for a degraded image using the RAM model.
35
+
36
+ Args:
37
+ image_path (str): Path to the degraded input image.
38
+ tag_model (torch.nn.Module): Preloaded RAM model.
39
+
40
+ Returns:
41
+ str: Generated caption for the image.
42
+ """
43
+ device = "cuda" if torch.cuda.is_available() else "cpu"
44
+
45
+ # Define image transformations
46
+ tensor_transforms = transforms.Compose([
47
+ transforms.ToTensor(),
48
+ ])
49
+ ram_transforms = transforms.Compose([
50
+ transforms.Resize((384, 384)),
51
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
52
+ ])
53
+
54
+ # Load and preprocess the image
55
+ image = Image.open(image_path).convert("RGB")
56
+ image_tensor = tensor_transforms(image).unsqueeze(0).to(device)
57
+ image_tensor = ram_transforms(image_tensor)
58
+
59
+ # Generate caption using the RAM model
60
+ caption = inference(image_tensor, tag_model)
61
+
62
+ return caption[0]
63
+
64
+ def process_images_in_directory(input_dir: str, output_file: str, tag_model):
65
+ """
66
+ Process all images in a directory, generate captions using the RAM model,
67
+ and save the captions to a file.
68
+
69
+ Args:
70
+ input_dir (str): Path to the directory containing input images.
71
+ output_file (str): Path to the file where captions will be saved.
72
+ tag_model (torch.nn.Module): Preloaded RAM model.
73
+ """
74
+ # Open the output file for writing captions
75
+ with open(output_file, "w") as f:
76
+ # Iterate through all files in the input directory
77
+ for filename in os.listdir(input_dir):
78
+ # Construct the full path to the image file
79
+ image_path = os.path.join(input_dir, filename)
80
+
81
+ # Check if the file is an image
82
+ if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
83
+ try:
84
+ # Generate a caption for the image
85
+ caption = generate_caption(image_path, tag_model)
86
+ print(f"Generated caption for {filename}: {caption}")
87
+ # Write the caption to the output file
88
+ f.write(f"{filename}: {caption}\n")
89
+
90
+ print(f"Processed {filename}: {caption}")
91
+ except Exception as e:
92
+ print(f"Error processing {filename}: {e}")
93
+
94
+ if __name__ == "__main__":
95
+ parser = argparse.ArgumentParser(description="Generate captions for images using RAM and DAPE models.")
96
+ parser.add_argument("--input_dir", type=str, default="data/val", help="Path to the directory containing input images.")
97
+ parser.add_argument("--output_file", type=str, default="data/val_captions.txt", help="Path to the file where captions will be saved.")
98
+ parser.add_argument("--ram_model", type=str, default="/home/erbachj/scratch2/projects/var_post_samp/thirdparty/SeeSR/preset/model/ram_swin_large_14m.pth", help="Path to the pretrained RAM model.")
99
+ parser.add_argument("--dape_model", type=str, default="/home/erbachj/scratch2/projects/var_post_samp/thirdparty/SeeSR/preset/model/DAPE.pth", help="Path to the pretrained DAPE model.")
100
+
101
+ args = parser.parse_args()
102
+
103
+ # Load the RAM model once
104
+ tag_model = load_ram_model(args.ram_model, args.dape_model)
105
+
106
+ # Process images in the directory
107
+ process_images_in_directory(args.input_dir, args.output_file, tag_model)
setup.py ADDED
@@ -0,0 +1,14 @@
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(
4
+ name="flair",
5
+ version="0.1.0",
6
+ author="Julius Erbach",
7
+ author_email="[email protected]",
8
+ description="Solving Inverse Problems with FLAIR",
9
+ package_dir={"": "src"},
10
+ packages=find_packages(where="src"),
11
+ include_package_data=True,
12
+ python_requires=">=3.7",
13
+ keywords="pytorch variational posterior sampling deep learning",
14
+ )
src/flair/__init__.py ADDED
File without changes
src/flair/degradations.py ADDED
@@ -0,0 +1,198 @@
1
+ import torch
2
+ import numpy as np
3
+ from munch import munchify
4
+ from scipy.ndimage import distance_transform_edt
5
+ from flair.functions.degradation import get_degradation
6
+ import torchvision
7
+
8
+ class BaseDegradation(torch.nn.Module):
9
+ def __init__(self, noise_std=0.0):
10
+ super().__init__()
11
+ self.noise_std = noise_std
12
+
13
+ def forward(self, x):
14
+ x = x + self.noise_std * torch.randn_like(x)
15
+ return x
16
+
17
+ def pseudo_inv(self, y):
18
+ return y
19
+
20
+
21
+ def zero_filler(x, scale):
22
+ B, C, H, W = x.shape
23
+ scale = int(scale)
24
+ H_new, W_new = H * scale, W * scale
25
+ out = torch.zeros(B, C, H_new, W_new, dtype=x.dtype, device=x.device)
26
+ out[:, :, ::scale, ::scale] = x
27
+ return out
28
+
29
+
30
+ class SuperRes(BaseDegradation):
31
+ def __init__(self, scale, noise_std=0.0, img_size=256):
32
+ super().__init__(noise_std=noise_std)
33
+ self.scale = scale
34
+ deg_config = munchify({
35
+ 'channels': 3,
36
+ 'image_size': img_size,
37
+ 'deg_scale': scale
38
+ })
39
+ self.img_size = img_size
40
+ self.deg = get_degradation("sr_bicubic", deg_config, device="cuda")
41
+
42
+ def forward(self, x, noise=True):
43
+ dtype = x.dtype
44
+ y = self.deg.A(x.float())
45
+ # add noise
46
+ if noise:
47
+ y = super().forward(y)
48
+
49
+ return y.to(dtype)
50
+
51
+ def pseudo_inv(self, y):
52
+ x = self.deg.At(y.float()).reshape(1,3,self.img_size, self.img_size)* self.scale**2
53
+ return x.to(y.dtype)
54
+
55
+
56
+ def nn(self, y):
57
+ x = torch.nn.functional.interpolate(
58
+ y.reshape(1,3,self.img_size//self.scale, self.img_size//self.scale), scale_factor=self.scale, mode="nearest"
59
+ )
60
+ return x.to(y.dtype)
61
+
62
+ class SuperResGradio(BaseDegradation):
63
+ def __init__(self, scale, noise_std=0.0, img_size=256):
64
+ super().__init__(noise_std=noise_std)
65
+ self.scale = scale
66
+ self.downscaler = lambda x: torch.nn.functional.interpolate(
67
+ x.float(), scale_factor=1/self.scale, mode="bilinear", align_corners=False, antialias=True
68
+ )
69
+ self.upscaler = lambda x: torch.nn.functional.interpolate(
70
+ x.float(), scale_factor=self.scale, mode="bilinear", align_corners=False, antialias=True
71
+ )
72
+ self.img_size = img_size
73
+
74
+ def forward(self, x, noise=True):
75
+ dtype = x.dtype
76
+ y = self.downscaler(x.float())
77
+ # add noise
78
+ if noise:
79
+ y = super().forward(y)
80
+
81
+ return y.to(dtype)
82
+
83
+ def pseudo_inv(self, y):
84
+ x = self.upscaler(y.float())
85
+ return x.to(y.dtype)
86
+
87
+
88
+ def nn(self, y):
89
+ x = torch.nn.functional.interpolate(
90
+ y.reshape(1,3,self.img_size//self.scale, self.img_size//self.scale), scale_factor=self.scale, mode="nearest-exact"
91
+ )
92
+ return x.to(y.dtype)
93
+
94
+
95
+ class Inpainting(BaseDegradation):
96
+ def __init__(self, mask, H, W, noise_std=0.0):
97
+ """
98
+ mask: torch.Tensor, shape (H, W), dtype bool
99
+ function assumes 3 channels
100
+ """
101
+ super().__init__(noise_std=noise_std)
102
+ if isinstance(mask, list):
103
+ # generate box mask from a [row_start, row_end, col_start, col_end] list
104
+ # observed region is True
105
+ mask_ = torch.ones(H, W, dtype=torch.bool)
106
+ mask_[slice(*mask[0:2]), slice(*mask[2:])] = False
107
+ # repeat for 3 channels
108
+ mask_ = mask_.repeat(3, 1, 1)
109
+ elif isinstance(mask, str):
110
+ # load mask file
111
+ mask_ = torch.tensor(np.load(mask), dtype=torch.bool)
112
+ mask_ = mask_.repeat(3, 1, 1)
113
+ elif isinstance(mask, torch.Tensor):
114
+ if mask.ndim == 2:
115
+ # assume mask is for one channel, repeat for 3 channels
116
+ mask_ = mask[None].repeat(3, 1, 1)
117
+ elif mask.ndim == 3 and mask.shape[0] == 1:
118
+ # assume mask is for one channel, repeat for 3 channels
119
+ mask_ = mask.repeat(3, 1, 1)
120
+ else:
121
+ mask_ = mask
122
+ else:
123
+ raise ValueError("Mask must be a list, string (file path), or torch.Tensor.")
124
+ self.mask = mask_
125
+ self.H, self.W = H, W
126
+
127
+ def forward(self, x, noise=True):
128
+ B = x.shape[0]
129
+ y = x[self.mask[None]].view(B, -1)
130
+ # add noise
131
+ if noise:
132
+ y = super().forward(y)
133
+ return y
134
+
135
+ def pseudo_inv(self, y):
136
+ x = torch.zeros(y.shape[0], 3 * self.H * self.W, dtype=y.dtype, device=y.device)
137
+ x[:, self.mask.view(-1)] = y
138
+ x = x.view(y.shape[0], 3, self.H, self.W)
139
+ # x = inpaint_nearest(x[0], self.mask[0])[None]
140
+ return x
141
+
142
+
143
+ def inpaint_nearest(image: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
144
+ """
145
+ Fill missing pixels in an image using the nearest observed pixel value.
146
+
147
+ Args:
148
+ image: A tensor of shape [C, H, W] representing the image.
149
+ mask: A tensor of shape [H, W] with 1 for observed pixels and 0 for missing pixels.
150
+
151
+ Returns:
152
+ A tensor of shape [C, H, W] where missing pixels have been filled.
153
+ """
154
+ # Move tensors to CPU and convert to numpy arrays.
155
+ image_np = image.cpu().float().numpy()
156
+ # Convert mask to boolean: True for observed, False for missing.
157
+ mask_np = mask.cpu().numpy().astype(bool)
158
+
159
+ # Compute the distance transform of the inverse mask (~mask_np).
160
+ # The function returns:
161
+ # - distances: distance to the nearest True pixel in mask_np
162
+ # - indices: the indices of that nearest True pixel for each pixel.
163
+ # indices has shape (2, H, W): first row is the row index, second row is the column index.
164
+ _, indices = distance_transform_edt(~mask_np, return_indices=True)
165
+
166
+ # Create a copy of the image to hold the filled values.
167
+ filled_image_np = np.empty_like(image_np)
168
+
169
+ # For each channel, replace every pixel with the value of the nearest observed pixel.
170
+ for c in range(image_np.shape[0]):
171
+ filled_image_np[c] = image_np[c, indices[0], indices[1]]
172
+
173
+ # Convert back to a torch tensor and send to the original device.
174
+ return torch.from_numpy(filled_image_np).to(image.device).to(image.dtype)
175
+
176
+ class MotionBlur(BaseDegradation):
177
+ def __init__(self, kernel_size=5, img_size=256, noise_std=0.0):
178
+ super().__init__(noise_std=noise_std)
179
+ deg_config = munchify({
180
+ 'channels': 3,
181
+ 'image_size': img_size,
182
+ 'deg_scale': kernel_size
183
+ })
184
+ self.img_size = img_size
185
+ self.deg = get_degradation("deblur_motion", deg_config, device="cuda")
186
+
187
+ def forward(self, x, noise=True):
188
+ dtype = x.dtype
189
+ y = self.deg.A(x.float())
190
+ # add noise
191
+ if noise:
192
+ y = super().forward(y)
193
+ return y.to(dtype)
194
+
195
+ def pseudo_inv(self, y):
196
+ dtype = y.dtype
197
+ x = self.deg.At(y.float()).reshape(1,3,self.img_size, self.img_size)
198
+ return x.to(dtype)
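Assumed usage of the Inpainting degradation above: the list form of `mask` is [row_start, row_end, col_start, col_end] and the selected box is hidden (observed pixels are True in the mask).

    import torch
    # from flair.degradations import Inpainting  # as added in this commit

    H = W = 8
    deg = Inpainting(mask=[2, 6, 2, 6], H=H, W=W, noise_std=0.0)
    x = torch.rand(1, 3, H, W) * 2 - 1   # image in [-1, 1]
    y = deg(x, noise=False)              # observed pixels only, shape (1, 3 * (H * W - 16))
    x_filled = deg.pseudo_inv(y)         # (1, 3, H, W) with zeros inside the hidden box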
src/flair/functions/__init__.py ADDED
File without changes
src/flair/functions/ckpt_util.py ADDED
@@ -0,0 +1,72 @@
1
+ import os, hashlib
2
+ import requests
3
+ from tqdm import tqdm
4
+
5
+ URL_MAP = {
6
+ "cifar10": "https://heibox.uni-heidelberg.de/f/869980b53bf5416c8a28/?dl=1",
7
+ "ema_cifar10": "https://heibox.uni-heidelberg.de/f/2e4f01e2d9ee49bab1d5/?dl=1",
8
+ "lsun_bedroom": "https://heibox.uni-heidelberg.de/f/f179d4f21ebc4d43bbfe/?dl=1",
9
+ "ema_lsun_bedroom": "https://heibox.uni-heidelberg.de/f/b95206528f384185889b/?dl=1",
10
+ "lsun_cat": "https://heibox.uni-heidelberg.de/f/fac870bd988348eab88e/?dl=1",
11
+ "ema_lsun_cat": "https://heibox.uni-heidelberg.de/f/0701aac3aa69457bbe34/?dl=1",
12
+ "lsun_church": "https://heibox.uni-heidelberg.de/f/2711a6f712e34b06b9d8/?dl=1",
13
+ "ema_lsun_church": "https://heibox.uni-heidelberg.de/f/44ccb50ef3c6436db52e/?dl=1",
14
+ }
15
+ CKPT_MAP = {
16
+ "cifar10": "diffusion_cifar10_model/model-790000.ckpt",
17
+ "ema_cifar10": "ema_diffusion_cifar10_model/model-790000.ckpt",
18
+ "lsun_bedroom": "diffusion_lsun_bedroom_model/model-2388000.ckpt",
19
+ "ema_lsun_bedroom": "ema_diffusion_lsun_bedroom_model/model-2388000.ckpt",
20
+ "lsun_cat": "diffusion_lsun_cat_model/model-1761000.ckpt",
21
+ "ema_lsun_cat": "ema_diffusion_lsun_cat_model/model-1761000.ckpt",
22
+ "lsun_church": "diffusion_lsun_church_model/model-4432000.ckpt",
23
+ "ema_lsun_church": "ema_diffusion_lsun_church_model/model-4432000.ckpt",
24
+ }
25
+ MD5_MAP = {
26
+ "cifar10": "82ed3067fd1002f5cf4c339fb80c4669",
27
+ "ema_cifar10": "1fa350b952534ae442b1d5235cce5cd3",
28
+ "lsun_bedroom": "f70280ac0e08b8e696f42cb8e948ff1c",
29
+ "ema_lsun_bedroom": "1921fa46b66a3665e450e42f36c2720f",
30
+ "lsun_cat": "bbee0e7c3d7abfb6e2539eaf2fb9987b",
31
+ "ema_lsun_cat": "646f23f4821f2459b8bafc57fd824558",
32
+ "lsun_church": "eb619b8a5ab95ef80f94ce8a5488dae3",
33
+ "ema_lsun_church": "fdc68a23938c2397caba4a260bc2445f",
34
+ }
35
+
36
+
37
+ def download(url, local_path, chunk_size=1024):
38
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
39
+ with requests.get(url, stream=True) as r:
40
+ total_size = int(r.headers.get("content-length", 0))
41
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
42
+ with open(local_path, "wb") as f:
43
+ for data in r.iter_content(chunk_size=chunk_size):
44
+ if data:
45
+ f.write(data)
46
+ pbar.update(chunk_size)
47
+
48
+
49
+ def md5_hash(path):
50
+ with open(path, "rb") as f:
51
+ content = f.read()
52
+ return hashlib.md5(content).hexdigest()
53
+
54
+
55
+ def get_ckpt_path(name, root=None, check=False, prefix='exp'):
56
+ if 'church_outdoor' in name:
57
+ name = name.replace('church_outdoor', 'church')
58
+ assert name in URL_MAP
59
+ # Modify the path when necessary
60
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.join(prefix, "logs/"))
61
+ root = (
62
+ root
63
+ if root is not None
64
+ else os.path.join(cachedir, "diffusion_models_converted")
65
+ )
66
+ path = os.path.join(root, CKPT_MAP[name])
67
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
68
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
69
+ download(URL_MAP[name], path)
70
+ md5 = md5_hash(path)
71
+ assert md5 == MD5_MAP[name], md5
72
+ return path
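Assumed usage of the checkpoint helper above: the file is downloaded on first call, the MD5 is verified when check=True, and the local path is returned.

    # from flair.functions.ckpt_util import get_ckpt_path  # as added in this commit
    path = get_ckpt_path("ema_cifar10", root="./checkpoints", check=True)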
src/flair/functions/conjugate_gradient.py ADDED
@@ -0,0 +1,66 @@
1
+ # From wikipedia. MATLAB code,
2
+ # function x = conjgrad(A, b, x)
3
+ # r = b - A * x;
4
+ # p = r;
5
+ # rsold = r' * r;
6
+ #
7
+ # for i = 1:length(b)
8
+ # Ap = A * p;
9
+ # alpha = rsold / (p' * Ap);
10
+ # x = x + alpha * p;
11
+ # r = r - alpha * Ap;
12
+ # rsnew = r' * r;
13
+ # if sqrt(rsnew) < 1e-10
14
+ # break
15
+ # end
16
+ # p = r + (rsnew / rsold) * p;
17
+ # rsold = rsnew;
18
+ # end
19
+ # end
20
+
21
+ from typing import Callable, Optional
22
+
23
+ import torch
24
+
25
+
26
+ def CG(A: Callable,
27
+ b: torch.Tensor,
28
+ x: torch.Tensor,
29
+ m: Optional[int]=5,
30
+ eps: Optional[float]=1e-4,
31
+ damping: float=0.0,
32
+ use_mm: bool=False) -> torch.Tensor:
33
+
34
+ if use_mm:
35
+ mm_fn = lambda x, y: torch.mm(x.view(1, -1), y.view(1, -1).T)
36
+ else:
37
+ mm_fn = lambda x, y: (x * y).flatten().sum()
38
+
39
+ orig_shape = x.shape
40
+ x = x.view(x.shape[0], -1)
41
+
42
+ r = b - A(x)
43
+ p = r.clone()
44
+
45
+ rsold = mm_fn(r, r)
46
+ assert not (rsold != rsold).any(), print(f'NaN detected 1')
47
+
48
+ for i in range(m):
49
+ Ap = A(p) + damping * p
50
+ alpha = rsold / mm_fn(p, Ap)
51
+ assert not (alpha != alpha).any(), print(f'NaN detected 2')
52
+
53
+ x = x + alpha * p
54
+ r = r - alpha * Ap
55
+
56
+ rsnew = mm_fn(r, r)
57
+ assert not (rsnew != rsnew).any(), print('NaN detected 3')
58
+
59
+ if rsnew.sqrt().abs() < eps:
60
+ break
61
+
62
+ p = r + (rsnew / rsold) * p
63
+ rsold = rsnew.clone()
64
+
65
+ return x.reshape(orig_shape)
66
+
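A small self-contained check (assumed usage) of the CG routine above on a 2x2 symmetric positive-definite system; A is passed as a matvec callable acting on row vectors of shape (batch, n).

    import torch
    # from flair.functions.conjugate_gradient import CG  # as added in this commit

    M = torch.tensor([[4.0, 1.0],
                      [1.0, 3.0]])
    A = lambda v: v @ M.T                 # symmetric, so M.T == M
    b = torch.tensor([[1.0, 2.0]])
    x0 = torch.zeros(1, 2)
    x = CG(A, b, x0, m=10, eps=1e-8)
    # x[0] should be close to [1/11, 7/11] ~ [0.0909, 0.6364]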
src/flair/functions/degradation.py ADDED
@@ -0,0 +1,211 @@
1
+ import numpy as np
2
+ import torch
3
+ from munch import Munch
4
+
5
+ from flair.functions import svd_operators as svd_op
6
+ from flair.functions import measurements
7
+ from flair.utils.inpaint_util import MaskGenerator
8
+
9
+ __DEGRADATION__ = {}
10
+
11
+ def register_degradation(name: str):
12
+ def wrapper(fn):
13
+ if __DEGRADATION__.get(name) is not None:
14
+ raise NameError(f'DEGRADATION {name} is already registered')
15
+ __DEGRADATION__[name]=fn
16
+ return fn
17
+ return wrapper
18
+
19
+ def get_degradation(name: str,
20
+ deg_config: Munch,
21
+ device:torch.device):
22
+ if __DEGRADATION__.get(name) is None:
23
+ raise NameError(f'DEGRADATION {name} does not exist.')
24
+ return __DEGRADATION__[name](deg_config, device)
25
+
26
+ @register_degradation(name='cs_walshhadamard')
27
+ def deg_cs_walshhadamard(deg_config, device):
28
+ compressed_size = round(1/deg_config.deg_scale)
29
+ A_funcs = svd_op.WalshHadamardCS(deg_config.channels,
30
+ deg_config.image_size,
31
+ compressed_size,
32
+ torch.randperm(deg_config.image_size**2),
33
+ device)
34
+ return A_funcs
35
+
36
+ @register_degradation(name='cs_blockbased')
37
+ def deg_cs_blockbased(deg_config, device):
38
+ cs_ratio = deg_config.deg_scale
39
+ A_funcs = svd_op.CS(deg_config.channels,
40
+ deg_config.image_size,
41
+ cs_ratio,
42
+ device)
43
+ return A_funcs
44
+
45
+ @register_degradation(name='inpainting')
46
+ def deg_inpainting(deg_config, device):
47
+ # TODO: generate mask rather than load
48
+ loaded = np.load("exp/inp_masks/mask_768_half.npy") # block
49
+ # loaded = np.load("lip_mask_4.npy")
50
+ mask = torch.from_numpy(loaded).to(device).reshape(-1)
51
+ missing_r = torch.nonzero(mask == 0).long().reshape(-1) * 3
52
+ missing_g = missing_r + 1
53
+ missing_b = missing_g + 1
54
+ missing = torch.cat([missing_r, missing_g, missing_b], dim=0)
55
+ A_funcs = svd_op.Inpainting(deg_config.channels,
56
+ deg_config.image_size,
57
+ missing,
58
+ device)
59
+ return A_funcs
60
+
61
+ @register_degradation(name='denoising')
62
+ def deg_denoise(deg_config, device):
63
+ A_funcs = svd_op.Denoising(deg_config.channels,
64
+ deg_config.image_size,
65
+ device)
66
+ return A_funcs
67
+
68
+ @register_degradation(name='colorization')
69
+ def deg_colorization(deg_config, device):
70
+ A_funcs = svd_op.Colorization(deg_config.image_size,
71
+ device)
72
+ return A_funcs
73
+
74
+
75
+ @register_degradation(name='sr_avgpool')
76
+ def deg_sr_avgpool(deg_config, device):
77
+ blur_by = int(deg_config.deg_scale)
78
+ A_funcs = svd_op.SuperResolution(deg_config.channels,
79
+ deg_config.image_size,
80
+ blur_by,
81
+ device)
82
+ return A_funcs
83
+
84
+ @register_degradation(name='sr_bicubic')
85
+ def deg_sr_bicubic(deg_config, device):
86
+ def bicubic_kernel(x, a=-0.5):
87
+ if abs(x) <= 1:
88
+ return (a + 2) * abs(x) ** 3 - (a + 3) * abs(x) ** 2 + 1
89
+ elif 1 < abs(x) and abs(x) < 2:
90
+ return a * abs(x) ** 3 - 5 * a * abs(x) ** 2 + 8 * a * abs(x) - 4 * a
91
+ else:
92
+ return 0
93
+
94
+ factor = int(deg_config.deg_scale)
95
+ k = np.zeros((factor * 4))
96
+ for i in range(factor * 4):
97
+ x = (1 / factor) * (i - np.floor(factor * 4 / 2) + 0.5)
98
+ k[i] = bicubic_kernel(x)
99
+ k = k / np.sum(k)
100
+ kernel = torch.from_numpy(k).float().to(device)
101
+ A_funcs = svd_op.SRConv(kernel / kernel.sum(),
102
+ deg_config.channels,
103
+ deg_config.image_size,
104
+ device,
105
+ stride=factor)
106
+ return A_funcs
107
+
108
+ @register_degradation(name='deblur_uni')
109
+ def deg_deblur_uni(deg_config, device):
110
+ A_funcs = svd_op.Deblurring(torch.tensor([1/deg_config.deg_scale]*deg_config.deg_scale).to(device),
111
+ deg_config.channels,
112
+ deg_config.image_size,
113
+ device)
114
+ return A_funcs
115
+
116
+ @register_degradation(name='deblur_gauss')
117
+ def deg_deblur_gauss(deg_config, device):
118
+ sigma = 3.0
119
+ pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2]))
120
+ size = deg_config.deg_scale
121
+ ker = []
122
+ for k in range(-size//2, size//2):
123
+ ker.append(pdf(k))
124
+ kernel = torch.Tensor(ker).to(device)
125
+ A_funcs = svd_op.Deblurring(kernel / kernel.sum(),
126
+ deg_config.channels,
127
+ deg_config.image_size,
128
+ device)
129
+ return A_funcs
130
+
131
+ @register_degradation(name='deblur_aniso')
132
+ def deg_deblur_aniso(deg_config, device):
133
+ sigma = 20
134
+ pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2]))
135
+ kernel2 = torch.Tensor([pdf(-4), pdf(-3), pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2), pdf(3), pdf(4)]).to(device)
136
+
137
+ sigma = 1
138
+ pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2]))
139
+ kernel1 = torch.Tensor([pdf(-4), pdf(-3), pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2), pdf(3), pdf(4)]).to(device)
140
+
141
+ A_funcs = svd_op.Deblurring2D(kernel1 / kernel1.sum(),
142
+ kernel2 / kernel2.sum(),
143
+ deg_config.channels,
144
+ deg_config.image_size,
145
+ device)
146
+ return A_funcs
147
+
148
+ @register_degradation(name='deblur_motion')
149
+ def deg_deblur_motion(deg_config, device):
150
+ A_funcs = measurements.MotionBlurOperator(
151
+ kernel_size=deg_config.deg_scale,
152
+ intensity=0.5,
153
+ device=device
154
+ )
155
+ return A_funcs
156
+
157
+ @register_degradation(name='deblur_nonuniform')
158
+ def deg_deblur_nonuniform(deg_config, device, kernels=None, masks=None):
159
+ A_funcs = measurements.NonuniformBlurOperator(
160
+ deg_config.image_size,
161
+ deg_config.deg_scale,
162
+ device,
163
+ kernels=kernels,
164
+ masks=masks,
165
+ )
166
+ return A_funcs
167
+
168
+
169
+ # ======= FOR arbitrary image size =======
170
+ @register_degradation(name='sr_avgpool_gen')
171
+ def deg_sr_avgpool_general(deg_config, device):
172
+ blur_by = int(deg_config.deg_scale)
173
+ A_funcs = svd_op.SuperResolutionGeneral(deg_config.channels,
174
+ deg_config.imgH,
175
+ deg_config.imgW,
176
+ blur_by,
177
+ device)
178
+ return A_funcs
179
+
180
+ @register_degradation(name='deblur_gauss_gen')
181
+ def deg_deblur_gauss_general(deg_config, device):
182
+ A_funcs = measurements.GaussialBlurOperator(
183
+ kernel_size=deg_config.deg_scale,
184
+ intensity=3.0,
185
+ device=device
186
+ )
187
+ return A_funcs
188
+
189
+
190
+ from flair.functions.jpeg import jpeg_encode, jpeg_decode
191
+
192
+ class JPEGOperator():
193
+ def __init__(self, qf: int, device):
194
+ self.qf = qf
195
+ self.device = device
196
+
197
+ def A(self, img):
198
+ x_luma, x_chroma = jpeg_encode(img, self.qf)
199
+ return x_luma, x_chroma
200
+
201
+ def At(self, encoded):
202
+ return jpeg_decode(encoded, self.qf)
203
+
204
+
205
+ @register_degradation(name='jpeg')
206
+ def deg_jpeg(deg_config, device):
207
+ A_funcs = JPEGOperator(
208
+ qf = deg_config.deg_scale,
209
+ device=device
210
+ )
211
+ return A_funcs
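Assumed usage of the degradation registry above, mirroring how degradations.py builds its operators (a GPU is assumed, as elsewhere in this repo): the config is a Munch object carrying the fields each factory reads.

    import torch
    from munch import munchify
    # from flair.functions.degradation import get_degradation  # as added in this commit

    cfg = munchify({"channels": 3, "image_size": 256, "deg_scale": 4})
    A_funcs = get_degradation("sr_bicubic", cfg, device="cuda")
    x = torch.rand(1, 3, 256, 256, device="cuda")
    y = A_funcs.A(x)        # measurement; A_funcs.At(y) applies the adjoint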
src/flair/functions/jpeg.py ADDED
@@ -0,0 +1,392 @@
1
+ # ---------------------------------------------------------------
2
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # This file has been modified from ddrm-jpeg.
5
+ #
6
+ # Source:
7
+ # https://github.com/bahjat-kawar/ddrm-jpeg/blob/master/functions/jpeg_torch.py
8
+ #
9
+ # The license for the original version of this file can be
10
+ # found in this directory (LICENSE_DDRM_JPEG).
11
+ # The modifications to this file are subject to the same license.
12
+ # ---------------------------------------------------------------
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ def dct1(x):
19
+ """
20
+ Discrete Cosine Transform, Type I
21
+ :param x: the input signal
22
+ :return: the DCT-I of the signal over the last dimension
23
+ """
24
+ x_shape = x.shape
25
+ x = x.view(-1, x_shape[-1])
26
+
27
+ return torch.fft.rfft(torch.cat([x, x.flip([1])[:, 1:-1]], dim=1))[:, :, 0].view(*x_shape)
28
+
29
+
30
+ def idct1(X):
31
+ """
32
+ The inverse of DCT-I, which is just a scaled DCT-I
33
+ Our definition of idct1 is such that idct1(dct1(x)) == x
34
+ :param X: the input signal
35
+ :return: the inverse DCT-I of the signal over the last dimension
36
+ """
37
+ n = X.shape[-1]
38
+ return dct1(X) / (2 * (n - 1))
39
+
40
+
41
+ def dct(x, norm=None):
42
+ """
43
+ Discrete Cosine Transform, Type II (a.k.a. the DCT)
44
+ For the meaning of the parameter `norm`, see:
45
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
46
+ :param x: the input signal
47
+ :param norm: the normalization, None or 'ortho'
48
+ :return: the DCT-II of the signal over the last dimension
49
+ """
50
+ x_shape = x.shape
51
+ N = x_shape[-1]
52
+ x = x.contiguous().view(-1, N)
53
+
54
+ v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
55
+
56
+ Vc = torch.view_as_real(torch.fft.fft(v, dim=1))
57
+
58
+ k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
59
+ W_r = torch.cos(k)
60
+ W_i = torch.sin(k)
61
+
62
+ V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
63
+
64
+ if norm == 'ortho':
65
+ V[:, 0] /= np.sqrt(N) * 2
66
+ V[:, 1:] /= np.sqrt(N / 2) * 2
67
+
68
+ V = 2 * V.view(*x_shape)
69
+
70
+ return V
71
+
72
+
73
+ def idct(X, norm=None):
74
+ """
75
+ The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
76
+ Our definition of idct is that idct(dct(x)) == x
77
+ For the meaning of the parameter `norm`, see:
78
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
79
+ :param X: the input signal
80
+ :param norm: the normalization, None or 'ortho'
81
+ :return: the inverse DCT-II of the signal over the last dimension
82
+ """
83
+
84
+ x_shape = X.shape
85
+ N = x_shape[-1]
86
+
87
+ X_v = X.contiguous().view(-1, x_shape[-1]) / 2
88
+
89
+ if norm == 'ortho':
90
+ X_v[:, 0] *= np.sqrt(N) * 2
91
+ X_v[:, 1:] *= np.sqrt(N / 2) * 2
92
+
93
+ k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
94
+ W_r = torch.cos(k)
95
+ W_i = torch.sin(k)
96
+
97
+ V_t_r = X_v
98
+ V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
99
+
100
+ V_r = V_t_r * W_r - V_t_i * W_i
101
+ V_i = V_t_r * W_i + V_t_i * W_r
102
+
103
+ V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
104
+
105
+ v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
106
+ x = v.new_zeros(v.shape)
107
+ x[:, ::2] += v[:, :N - (N // 2)]
108
+ x[:, 1::2] += v.flip([1])[:, :N // 2]
109
+
110
+ return x.view(*x_shape)
111
+
112
+
113
+ def dct_2d(x, norm=None):
114
+ """
115
+ 2-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)
116
+ For the meaning of the parameter `norm`, see:
117
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
118
+ :param x: the input signal
119
+ :param norm: the normalization, None or 'ortho'
120
+ :return: the DCT-II of the signal over the last 2 dimensions
121
+ """
122
+ X1 = dct(x, norm=norm)
123
+ X2 = dct(X1.transpose(-1, -2), norm=norm)
124
+ return X2.transpose(-1, -2)
125
+
126
+
127
+ def idct_2d(X, norm=None):
128
+ """
129
+ The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III
130
+ Our definition of idct is that idct_2d(dct_2d(x)) == x
131
+ For the meaning of the parameter `norm`, see:
132
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
133
+ :param X: the input signal
134
+ :param norm: the normalization, None or 'ortho'
135
+ :return: the DCT-II of the signal over the last 2 dimensions
136
+ """
137
+ x1 = idct(X, norm=norm)
138
+ x2 = idct(x1.transpose(-1, -2), norm=norm)
139
+ return x2.transpose(-1, -2)
140
+
141
+
142
+ def dct_3d(x, norm=None):
143
+ """
144
+ 3-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)
145
+ For the meaning of the parameter `norm`, see:
146
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
147
+ :param x: the input signal
148
+ :param norm: the normalization, None or 'ortho'
149
+ :return: the DCT-II of the signal over the last 3 dimensions
150
+ """
151
+ X1 = dct(x, norm=norm)
152
+ X2 = dct(X1.transpose(-1, -2), norm=norm)
153
+ X3 = dct(X2.transpose(-1, -3), norm=norm)
154
+ return X3.transpose(-1, -3).transpose(-1, -2)
155
+
156
+
157
+ def idct_3d(X, norm=None):
158
+ """
159
+ The inverse to 3D DCT-II, which is a scaled Discrete Cosine Transform, Type III
160
+ Our definition of idct is that idct_3d(dct_3d(x)) == x
161
+ For the meaning of the parameter `norm`, see:
162
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
163
+ :param X: the input signal
164
+ :param norm: the normalization, None or 'ortho'
165
+ :return: the DCT-II of the signal over the last 3 dimensions
166
+ """
167
+ x1 = idct(X, norm=norm)
168
+ x2 = idct(x1.transpose(-1, -2), norm=norm)
169
+ x3 = idct(x2.transpose(-1, -3), norm=norm)
170
+ return x3.transpose(-1, -3).transpose(-1, -2)
171
+
172
+
173
+ class LinearDCT(nn.Linear):
174
+ """Implement any DCT as a linear layer; in practice this executes around
175
+ 50x faster on GPU. Unfortunately, the DCT matrix is stored, which will
176
+ increase memory usage.
177
+ :param in_features: size of expected input
178
+ :param type: which dct function in this file to use"""
179
+ def __init__(self, in_features, type, norm=None, bias=False):
180
+ self.type = type
181
+ self.N = in_features
182
+ self.norm = norm
183
+ super(LinearDCT, self).__init__(in_features, in_features, bias=bias)
184
+
185
+ def reset_parameters(self):
186
+ # initialise using dct function
187
+ I = torch.eye(self.N)
188
+ if self.type == 'dct1':
189
+ self.weight.data = dct1(I).data.t()
190
+ elif self.type == 'idct1':
191
+ self.weight.data = idct1(I).data.t()
192
+ elif self.type == 'dct':
193
+ self.weight.data = dct(I, norm=self.norm).data.t()
194
+ elif self.type == 'idct':
195
+ self.weight.data = idct(I, norm=self.norm).data.t()
196
+ self.weight.requires_grad = False # don't learn this!
197
+
198
+
199
+ def apply_linear_2d(x, linear_layer):
200
+ """Can be used with a LinearDCT layer to do a 2D DCT.
201
+ :param x: the input signal
202
+ :param linear_layer: any PyTorch Linear layer
203
+ :return: result of linear layer applied to last 2 dimensions
204
+ """
205
+ X1 = linear_layer(x)
206
+ X2 = linear_layer(X1.transpose(-1, -2))
207
+ return X2.transpose(-1, -2)
208
+
209
+
210
+ def apply_linear_3d(x, linear_layer):
211
+ """Can be used with a LinearDCT layer to do a 3D DCT.
212
+ :param x: the input signal
213
+ :param linear_layer: any PyTorch Linear layer
214
+ :return: result of linear layer applied to last 3 dimensions
215
+ """
216
+ X1 = linear_layer(x)
217
+ X2 = linear_layer(X1.transpose(-1, -2))
218
+ X3 = linear_layer(X2.transpose(-1, -3))
219
+ return X3.transpose(-1, -3).transpose(-1, -2)
220
+
221
+
222
+ def torch_rgb2ycbcr(x):
223
+ # Assume x is a batch of size (N x C x H x W)
224
+ v = torch.tensor([[.299, .587, .114], [-.1687, -.3313, .5], [.5, -.4187, -.0813]]).to(x.device)
225
+ ycbcr = torch.tensordot(x, v, dims=([1], [1])).transpose(3, 2).transpose(2, 1)
226
+ ycbcr[:,1:] += 128
227
+ return ycbcr
228
+
229
+
230
+ def torch_ycbcr2rgb(x):
231
+ # Assume x is a batch of size (N x C x H x W)
232
+ v = torch.tensor([[ 1.00000000e+00, -3.68199903e-05, 1.40198758e+00],
233
+ [ 1.00000000e+00, -3.44113281e-01, -7.14103821e-01],
234
+ [ 1.00000000e+00, 1.77197812e+00, -1.34583413e-04]]).to(x.device)
235
+ x[:, 1:] -= 128
236
+ rgb = torch.tensordot(x, v, dims=([1], [1])).transpose(3, 2).transpose(2, 1)
237
+ return rgb
238
+
239
+ def chroma_subsample(x):
240
+ return x[:, 0:1, :, :], x[:, 1:, ::2, ::2]
241
+
242
+
243
+ def general_quant_matrix(qf = 10):
244
+ q1 = torch.tensor([
245
+ 16, 11, 10, 16, 24, 40, 51, 61,
246
+ 12, 12, 14, 19, 26, 58, 60, 55,
247
+ 14, 13, 16, 24, 40, 57, 69, 56,
248
+ 14, 17, 22, 29, 51, 87, 80, 62,
249
+ 18, 22, 37, 56, 68, 109, 103, 77,
250
+ 24, 35, 55, 64, 81, 104, 113, 92,
251
+ 49, 64, 78, 87, 103, 121, 120, 101,
252
+ 72, 92, 95, 98, 112, 100, 103, 99
253
+ ])
254
+ q2 = torch.tensor([
255
+ 17, 18, 24, 47, 99, 99, 99, 99,
256
+ 18, 21, 26, 66, 99, 99, 99, 99,
257
+ 24, 26, 56, 99, 99, 99, 99, 99,
258
+ 47, 66, 99, 99, 99, 99, 99, 99,
259
+ 99, 99, 99, 99, 99, 99, 99, 99,
260
+ 99, 99, 99, 99, 99, 99, 99, 99,
261
+ 99, 99, 99, 99, 99, 99, 99, 99,
262
+ 99, 99, 99, 99, 99, 99, 99, 99
263
+ ])
264
+ s = (5000 / qf) if qf < 50 else (200 - 2 * qf)
265
+ q1 = torch.floor((s * q1 + 50) / 100)
266
+ q1[q1 <= 0] = 1
267
+ q1[q1 > 255] = 255
268
+ q2 = torch.floor((s * q2 + 50) / 100)
269
+ q2[q2 <= 0] = 1
270
+ q2[q2 > 255] = 255
271
+ return q1, q2
272
+
273
+
274
+ def quantization_matrix(qf):
275
+ return general_quant_matrix(qf)
276
+ # q1 = torch.tensor([[ 80, 55, 50, 80, 120, 200, 255, 255],
277
+ # [ 60, 60, 70, 95, 130, 255, 255, 255],
278
+ # [ 70, 65, 80, 120, 200, 255, 255, 255],
279
+ # [ 70, 85, 110, 145, 255, 255, 255, 255],
280
+ # [ 90, 110, 185, 255, 255, 255, 255, 255],
281
+ # [120, 175, 255, 255, 255, 255, 255, 255],
282
+ # [245, 255, 255, 255, 255, 255, 255, 255],
283
+ # [255, 255, 255, 255, 255, 255, 255, 255]])
284
+ # q2 = torch.tensor([[ 85, 90, 120, 235, 255, 255, 255, 255],
285
+ # [ 90, 105, 130, 255, 255, 255, 255, 255],
286
+ # [120, 130, 255, 255, 255, 255, 255, 255],
287
+ # [235, 255, 255, 255, 255, 255, 255, 255],
288
+ # [255, 255, 255, 255, 255, 255, 255, 255],
289
+ # [255, 255, 255, 255, 255, 255, 255, 255],
290
+ # [255, 255, 255, 255, 255, 255, 255, 255],
291
+ # [255, 255, 255, 255, 255, 255, 255, 255]])
292
+ # return q1, q2
293
+
294
+ def jpeg_encode(x, qf):
295
+ # Assume x is a batch of size (N x C x H x W)
296
+ # [-1, 1] to [0, 255]
297
+ x = (x + 1) / 2 * 255
298
+ n_batch, _, n_size, _ = x.shape
299
+
300
+ x = torch_rgb2ycbcr(x)
301
+ x_luma, x_chroma = chroma_subsample(x)
302
+ unfold = nn.Unfold(kernel_size=(8, 8), stride=(8, 8))
303
+ x_luma = unfold(x_luma).transpose(2, 1)
304
+ x_chroma = unfold(x_chroma).transpose(2, 1)
305
+
306
+ x_luma = x_luma.reshape(-1, 8, 8) - 128
307
+ x_chroma = x_chroma.reshape(-1, 8, 8) - 128
308
+
309
+ dct_layer = LinearDCT(8, 'dct', norm='ortho')
310
+ dct_layer.to(x_luma.device)
311
+ x_luma = apply_linear_2d(x_luma, dct_layer)
312
+ x_chroma = apply_linear_2d(x_chroma, dct_layer)
313
+
314
+ x_luma = x_luma.view(-1, 1, 8, 8)
315
+ x_chroma = x_chroma.view(-1, 2, 8, 8)
316
+
317
+ q1, q2 = quantization_matrix(qf)
318
+ q1 = q1.to(x_luma.device)
319
+ q2 = q2.to(x_luma.device)
320
+ x_luma /= q1.view(1, 8, 8)
321
+ x_chroma /= q2.view(1, 8, 8)
322
+
323
+ x_luma = x_luma.round()
324
+ x_chroma = x_chroma.round()
325
+
326
+ x_luma = x_luma.reshape(n_batch, (n_size // 8) ** 2, 64).transpose(2, 1)
327
+ x_chroma = x_chroma.reshape(n_batch, (n_size // 16) ** 2, 64 * 2).transpose(2, 1)
328
+
329
+ fold = nn.Fold(output_size=(n_size, n_size), kernel_size=(8, 8), stride=(8, 8))
330
+ x_luma = fold(x_luma)
331
+ fold = nn.Fold(output_size=(n_size // 2, n_size // 2), kernel_size=(8, 8), stride=(8, 8))
332
+ x_chroma = fold(x_chroma)
333
+
334
+ return [x_luma, x_chroma]
335
+
336
+
337
+
338
+ def jpeg_decode(x, qf):
339
+ # Assume x[0] is a batch of size (N x 1 x H x W) (luma)
340
+ # Assume x[1:] is a batch of size (N x 2 x H/2 x W/2) (chroma)
341
+ x_luma, x_chroma = x
342
+ n_batch, _, n_size, _ = x_luma.shape
343
+ unfold = nn.Unfold(kernel_size=(8, 8), stride=(8, 8))
344
+ x_luma = unfold(x_luma).transpose(2, 1)
345
+ x_luma = x_luma.reshape(-1, 1, 8, 8)
346
+ x_chroma = unfold(x_chroma).transpose(2, 1)
347
+ x_chroma = x_chroma.reshape(-1, 2, 8, 8)
348
+
349
+ q1, q2 = quantization_matrix(qf)
350
+ q1 = q1.to(x_luma.device)
351
+ q2 = q2.to(x_luma.device)
352
+ x_luma *= q1.view(1, 8, 8)
353
+ x_chroma *= q2.view(1, 8, 8)
354
+
355
+ x_luma = x_luma.reshape(-1, 8, 8)
356
+ x_chroma = x_chroma.reshape(-1, 8, 8)
357
+
358
+ dct_layer = LinearDCT(8, 'idct', norm='ortho')
359
+ dct_layer.to(x_luma.device)
360
+ x_luma = apply_linear_2d(x_luma, dct_layer)
361
+ x_chroma = apply_linear_2d(x_chroma, dct_layer)
362
+
363
+ x_luma = (x_luma + 128).reshape(n_batch, (n_size // 8) ** 2, 64).transpose(2, 1)
364
+ x_chroma = (x_chroma + 128).reshape(n_batch, (n_size // 16) ** 2, 64 * 2).transpose(2, 1)
365
+
366
+ fold = nn.Fold(output_size=(n_size, n_size), kernel_size=(8, 8), stride=(8, 8))
367
+ x_luma = fold(x_luma)
368
+ fold = nn.Fold(output_size=(n_size // 2, n_size // 2), kernel_size=(8, 8), stride=(8, 8))
369
+ x_chroma = fold(x_chroma)
370
+
371
+ x_chroma_repeated = torch.zeros(n_batch, 2, n_size, n_size, device = x_luma.device)
372
+ x_chroma_repeated[:, :, 0::2, 0::2] = x_chroma
373
+ x_chroma_repeated[:, :, 0::2, 1::2] = x_chroma
374
+ x_chroma_repeated[:, :, 1::2, 0::2] = x_chroma
375
+ x_chroma_repeated[:, :, 1::2, 1::2] = x_chroma
376
+
377
+ x = torch.cat([x_luma, x_chroma_repeated], dim=1)
378
+
379
+ x = torch_ycbcr2rgb(x)
380
+
381
+ # [0, 255] to [-1, 1]
382
+ x = x / 255 * 2 - 1
383
+
384
+ return x
385
+
386
+
387
+ def build_jpeg(qf):
388
+ # log.info(f"[Corrupt] JPEG restoration: {qf=} ...")
389
+ def jpeg(img):
390
+ encoded = jpeg_encode(img, qf)
391
+ return jpeg_decode(encoded, qf), encoded
392
+ return jpeg
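Assumed usage of the differentiable JPEG round-trip defined above; the input must be in [-1, 1] with a side length divisible by 16 (8x8 DCT blocks after 2x chroma subsampling).

    import torch
    # from flair.functions.jpeg import build_jpeg  # as added in this commit

    jpeg = build_jpeg(qf=10)
    x = torch.rand(1, 3, 64, 64) * 2 - 1
    x_jpeg, (luma, chroma) = jpeg(x)
    print(x_jpeg.shape, luma.shape, chroma.shape)  # (1,3,64,64) (1,1,64,64) (1,2,32,32)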
src/flair/functions/measurements.py ADDED
@@ -0,0 +1,429 @@
1
+ '''This module handles task-dependent operations (A) and noises (n) to simulate a measurement y=Ax+n.'''
2
+
3
+ from abc import ABC, abstractmethod
4
+ from functools import partial
5
+
6
+ from torch.nn import functional as F
7
+ import torch
8
+
9
+ from flair.utils.blur_util import Blurkernel
10
+ from flair.utils.img_util import fft2d
11
+ import numpy as np
12
+ from flair.utils.resizer import Resizer
13
+ from flair.utils.utils_sisr import pre_calculate_FK, pre_calculate_nonuniform
14
+ from torch.fft import fft2, ifft2
15
+
16
+
17
+ from flair.motionblur.motionblur import Kernel
18
+
19
+ # =================
20
+ # Operation classes
21
+ # =================
22
+
23
+ __OPERATOR__ = {}
24
+ _GAMMA_FACTOR = 2.2
25
+
26
+ def register_operator(name: str):
27
+ def wrapper(cls):
28
+ if __OPERATOR__.get(name, None):
29
+ raise NameError(f"Name {name} is already registered!")
30
+ __OPERATOR__[name] = cls
31
+ return cls
32
+ return wrapper
33
+
34
+
35
+ def get_operator(name: str, **kwargs):
36
+ if __OPERATOR__.get(name, None) is None:
37
+ raise NameError(f"Name {name} is not defined.")
38
+ return __OPERATOR__[name](**kwargs)
39
+
40
+
41
+ class LinearOperator(ABC):
42
+ @abstractmethod
43
+ def forward(self, data, **kwargs):
44
+ # calculate A * X
45
+ pass
46
+
47
+ @abstractmethod
48
+ def noisy_forward(self, data, **kwargs):
49
+ # calculate A * X + n
50
+ pass
51
+
52
+ @abstractmethod
53
+ def transpose(self, data, **kwargs):
54
+ # calculate A^T * X
55
+ pass
56
+
57
+ def ortho_project(self, data, **kwargs):
58
+ # calculate (I - A^T * A)X
59
+ return data - self.transpose(self.forward(data, **kwargs), **kwargs)
60
+
61
+ def project(self, data, measurement, **kwargs):
62
+ # calculate (I - A^T * A)Y - AX
63
+ return self.ortho_project(measurement, **kwargs) - self.forward(data, **kwargs)
64
+
65
+
66
+ @register_operator(name='noise')
67
+ class DenoiseOperator(LinearOperator):
68
+ def __init__(self, device):
69
+ self.device = device
70
+
71
+ def forward(self, data):
72
+ return data
73
+
74
+ def noisy_forward(self, data):
75
+ return data
76
+
77
+ def transpose(self, data):
78
+ return data
79
+
80
+ def ortho_project(self, data):
81
+ return data
82
+
83
+ def project(self, data):
84
+ return data
85
+
86
+
87
+ @register_operator(name='sr_bicubic')
88
+ class SuperResolutionOperator(LinearOperator):
89
+ def __init__(self,
90
+ in_shape,
91
+ scale_factor,
92
+ noise,
93
+ noise_scale,
94
+ device):
95
+ self.device = device
96
+ self.up_sample = partial(F.interpolate, scale_factor=scale_factor)
97
+ self.down_sample = Resizer(in_shape, 1/scale_factor).to(device)
98
+ self.noise = get_noise(name=noise, scale=noise_scale)
99
+
100
+ def A(self, data, **kwargs):
101
+ return self.forward(data, **kwargs)
102
+
103
+ def forward(self, data, **kwargs):
104
+ return self.down_sample(data)
105
+
106
+ def noisy_forward(self, data, **kwargs):
107
+ return self.noise.forward(self.down_sample(data))
108
+
109
+ def transpose(self, data, **kwargs):
110
+ return self.up_sample(data)
111
+
112
+ def project(self, data, measurement, **kwargs):
113
+ return data - self.transpose(self.forward(data)) + self.transpose(measurement)
114
+
115
+ @register_operator(name='deblur_motion')
116
+ class MotionBlurOperator(LinearOperator):
117
+ def __init__(self,
118
+ kernel_size,
119
+ intensity,
120
+ device):
121
+ self.device = device
122
+ self.kernel_size = kernel_size
123
+ self.conv = Blurkernel(blur_type='motion',
124
+ kernel_size=kernel_size,
125
+ std=intensity,
126
+ device=device).to(device) # should we keep this device term?
127
+
128
+ self.kernel_size =kernel_size
129
+ self.intensity = intensity
130
+ self.kernel = Kernel(size=(kernel_size, kernel_size), intensity=intensity)
131
+ kernel = torch.tensor(self.kernel.kernelMatrix, dtype=torch.float32)
132
+ self.conv.update_weights(kernel)
133
+
134
+ def forward(self, data, **kwargs):
135
+ # calculate A * X (apply the motion-blur convolution)
136
+ return self.conv(data)
137
+
138
+ def noisy_forward(self, data, **kwargs):
139
+ pass
140
+
141
+ def transpose(self, data, **kwargs):
142
+ return data
143
+
144
+ def change_kernel(self):
145
+ self.kernel = Kernel(size=(self.kernel_size, self.kernel_size), intensity=self.intensity)
146
+ kernel = torch.tensor(self.kernel.kernelMatrix, dtype=torch.float32)
147
+ self.conv.update_weights(kernel)
148
+
149
+ def get_kernel(self):
150
+ kernel = torch.tensor(self.kernel.kernelMatrix, dtype=torch.float32).to(self.device)
151
+ return kernel.view(1, 1, self.kernel_size, self.kernel_size)
152
+
153
+ def A(self, data):
154
+ return self.forward(data)
155
+
156
+ def At(self, data):
157
+ return self.transpose(data)
158
+
159
+ @register_operator(name='deblur_gauss')
160
+ class GaussialBlurOperator(LinearOperator):
161
+ def __init__(self,
162
+ kernel_size,
163
+ intensity,
164
+ device):
165
+ self.device = device
166
+ self.kernel_size = kernel_size
167
+ self.conv = Blurkernel(blur_type='gaussian',
168
+ kernel_size=kernel_size,
169
+ std=intensity,
170
+ device=device).to(device)
171
+ self.kernel = self.conv.get_kernel()
172
+ self.conv.update_weights(self.kernel.type(torch.float32))
173
+
174
+ def forward(self, data, **kwargs):
175
+ return self.conv(data)
176
+
177
+ def noisy_forward(self, data, **kwargs):
178
+ pass
179
+
180
+ def transpose(self, data, **kwargs):
181
+ return data
182
+
183
+ def get_kernel(self):
184
+ return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)
185
+
186
+ def apply_kernel(self, data, kernel):
187
+ self.conv.update_weights(kernel.type(torch.float32))
188
+ return self.conv(data)
189
+
190
+ def A(self, data):
191
+ return self.forward(data)
192
+
193
+ def At(self, data):
194
+ return self.transpose(data)
195
+
196
+ @register_operator(name='inpainting')
197
+ class InpaintingOperator(LinearOperator):
198
+ '''This operator takes a pre-defined mask and returns the masked image.'''
199
+ def __init__(self,
200
+ noise,
201
+ noise_scale,
202
+ device):
203
+ self.device = device
204
+ self.noise = get_noise(name=noise, scale=noise_scale)
205
+
206
+ def forward(self, data, **kwargs):
207
+ try:
208
+ return data * kwargs.get('mask', None).to(self.device)
209
+ except AttributeError:
210
+ raise ValueError("InpaintingOperator requires a 'mask' keyword argument.")
211
+
212
+ def noisy_forward(self, data, **kwargs):
213
+ return self.noise.forward(self.forward(data, **kwargs))
214
+
215
+ def transpose(self, data, **kwargs):
216
+ return data
217
+
218
+ def ortho_project(self, data, **kwargs):
219
+ return data - self.forward(data, **kwargs)
220
+
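For this mask operator the transpose is the identity, so the projections above simplify to elementwise masking: ortho_project(x, mask=m) = x - m * x = (1 - m) * x, i.e. it keeps exactly the pixels the measurement does not observe. With m assumed to be a binary {0, 1} mask of the same spatial size as x (not enforced by the code), forward and ortho_project together split x into its observed and unobserved parts.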
221
+ # Operator for BlindDPS.
222
+ @register_operator(name='blind_blur')
223
+ class BlindBlurOperator(LinearOperator):
224
+ def __init__(self, device, **kwargs) -> None:
225
+ self.device = device
226
+
227
+ def forward(self, data, kernel, **kwargs):
228
+ return self.apply_kernel(data, kernel)
229
+
230
+ def transpose(self, data, **kwargs):
231
+ return data
232
+
233
+ def apply_kernel(self, data, kernel):
234
+ # TODO: faster way to apply conv?
235
+
236
+ b_img = torch.zeros_like(data).to(self.device)
237
+ for i in range(3):
238
+ b_img[:, i, :, :] = F.conv2d(data[:, i:i+1, :, :], kernel, padding='same')
239
+ return b_img
240
+
241
+
242
+ class NonLinearOperator(ABC):
243
+ @abstractmethod
244
+ def forward(self, data, **kwargs):
245
+ pass
246
+
247
+ @abstractmethod
248
+ def noisy_forward(self, data, **kwargs):
249
+ pass
250
+
251
+ def project(self, data, measurement, **kwargs):
252
+ return data + measurement - self.forward(data)
253
+
254
+ @register_operator(name='phase_retrieval')
255
+ class PhaseRetrievalOperator(NonLinearOperator):
256
+ def __init__(self,
257
+ oversample,
258
+ noise,
259
+ noise_scale,
260
+ device):
261
+ self.pad = int((oversample / 8.0) * 256)
262
+ self.device = device
263
+ self.noise = get_noise(name=noise, scale=noise_scale)
264
+
265
+ def forward(self, data, **kwargs):
266
+ padded = F.pad(data, (self.pad, self.pad, self.pad, self.pad))
267
+ amplitude = fft2d(padded).abs()
268
+ return amplitude
269
+
270
+ def noisy_forward(self, data, **kwargs):
271
+ return self.noise.forward(self.forward(data, **kwargs))
272
+
273
+ @register_operator(name='nonuniform_blur')
274
+ class NonuniformBlurOperator(LinearOperator):
275
+ def __init__(self, in_shape, kernel_size, device,
276
+ kernels=None, masks=None):
277
+ self.device = device
278
+ self.kernel_size = kernel_size
279
+ self.in_shape = in_shape
280
+
281
+ # TODO: generalize
282
+ if kernels is None and masks is None:
283
+ self.kernels = np.load('./functions/nonuniform/kernels/000001.npy')
284
+ self.masks = np.load('./functions/nonuniform/masks/000001.npy')
285
+ self.kernels = torch.tensor(self.kernels).to(device)
286
+ self.masks = torch.tensor(self.masks).to(device)
287
+
288
+ # approximate in image space
289
+ def forward_img(self, data):
290
+ K = self.kernel_size
291
+ data = F.pad(data, (K//2, K//2, K//2, K//2), mode="reflect")
292
+ kernels = self.kernels.transpose(0, 1)
293
+ data_rgb_batch = data.transpose(0, 1)
294
+ conv_rgb_batch = F.conv2d(data_rgb_batch, kernels)
295
+ y_rgb_batch = conv_rgb_batch * self.masks
296
+ y_rgb_batch = torch.sum(y_rgb_batch, dim=1, keepdim=True)
297
+ y = y_rgb_batch.transpose(0, 1)
298
+ return y
299
+
300
+ # NOTE: only this method makes the problem nonlinear (it applies gamma-correction)
301
+ def forward_nonlinear(self, data, flatten=False, noiseless=False):
302
+ # 1. Usual nonuniform blurring degradation pipeline
303
+ kernels = self.kernels.transpose(0, 1)
304
+ FK, FKC = pre_calculate_FK(kernels)
305
+ y = ifft2(FK * fft2(data)).real
306
+ y = y.transpose(0, 1)
307
+ y_rgb_batch = self.masks * y
308
+ y_rgb_batch = torch.sum(y_rgb_batch, dim=1, keepdim=True)
309
+ y = y_rgb_batch.transpose(0, 1)
310
+ F2KM, FKFMy = pre_calculate_nonuniform(data, y, FK, FKC, self.masks)
311
+ self.pre_calculated = (FK, FKC, F2KM, FKFMy)
312
+ # 2. Gamma-correction
313
+ y = (y + 1) / 2
314
+ y = y ** (1 / _GAMMA_FACTOR)
315
+ y = (y - 0.5) / 0.5
316
+ return y
317
+
318
+ def noisy_forward(self, data, **kwargs):
319
+ return self.noise.forward(self.forward(data))
320
+
321
+ # exact in Fourier
322
+ def forward(self, data, flatten=False, noiseless=False):
323
+ # [1, 25, 33, 33] -> [25, 1, 33, 33]
324
+ kernels = self.kernels.transpose(0, 1)
325
+ # [25, 1, 512, 512]
326
+ FK, FKC = pre_calculate_FK(kernels)
327
+ # [25, 3, 512, 512]
328
+ y = ifft2(FK * fft2(data)).real
329
+ # [3, 25, 512, 512]
330
+ y = y.transpose(0, 1)
331
+ y_rgb_batch = self.masks * y
332
+ # [3, 1, 512, 512]
333
+ y_rgb_batch = torch.sum(y_rgb_batch, dim=1, keepdim=True)
334
+ # [1, 3, 512, 512]
335
+ y = y_rgb_batch.transpose(0, 1)
336
+ F2KM, FKFMy = pre_calculate_nonuniform(data, y, FK, FKC, self.masks)
337
+ self.pre_calculated = (FK, FKC, F2KM, FKFMy)
338
+ return y
339
+
340
+ def transpose(self, y, flatten=False):
341
+ kernels = self.kernels.transpose(0, 1)
342
+ FK, FKC = pre_calculate_FK(kernels)
343
+ # 1. broadcast and multiply
344
+ # [3, 1, 512, 512]
345
+ y_rgb_batch = y.transpose(0, 1)
346
+ # [3, 25, 512, 512]
347
+ y_rgb_batch = y_rgb_batch.repeat(1, 25, 1, 1)
348
+ y = self.masks * y_rgb_batch
349
+ # 2. transpose of convolution in Fourier
350
+ # [25, 3, 512, 512]
351
+ y = y.transpose(0, 1)
352
+ ATy_broadcast = ifft2(FKC * fft2(y)).real
353
+ # [1, 3, 512, 512]
354
+ ATy = torch.sum(ATy_broadcast, dim=0, keepdim=True)
355
+ return ATy
356
+
357
+ def A(self, data):
358
+ return self.forward(data)
359
+
360
+ def At(self, data):
361
+ return self.transpose(data)
362
+
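Both forward and transpose above rely on the convolution theorem: each of the 25 per-region kernels is applied as ifft2(FK * fft2(x)), a pointwise product in the Fourier domain, and the adjoint reuses the conjugate spectrum FKC. In operator form this is roughly A(x) = sum_k m_k * (k_k conv x) and A^T(y) = sum_k flip(k_k) conv (m_k * y), where m_k are the spatial masks; the masked sums blend the per-region results into a single image.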
363
+ # =============
364
+ # Noise classes
365
+ # =============
366
+
367
+
368
+ __NOISE__ = {}
369
+
370
+ def register_noise(name: str):
371
+ def wrapper(cls):
372
+ if __NOISE__.get(name, None):
373
+ raise NameError(f"Name {name} is already defined!")
374
+ __NOISE__[name] = cls
375
+ return cls
376
+ return wrapper
377
+
378
+ def get_noise(name: str, **kwargs):
379
+ if __NOISE__.get(name, None) is None:
380
+ raise NameError(f"Name {name} is not defined.")
381
+ noiser = __NOISE__[name](**kwargs)
382
+ noiser.__name__ = name
383
+ return noiser
384
+
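The noise registry mirrors the operator registry; a short sketch with illustrative values:

    noiser = get_noise(name='gaussian', scale=0.05)
    y_noisy = noiser(y)   # Noise.__call__ dispatches to forward, adding N(0, 0.05^2) noise to y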
385
+ class Noise(ABC):
386
+ def __call__(self, data):
387
+ return self.forward(data)
388
+
389
+ @abstractmethod
390
+ def forward(self, data):
391
+ pass
392
+
393
+ @register_noise(name='clean')
394
+ class Clean(Noise):
395
+ def __init__(self, **kwargs):
396
+ pass
397
+
398
+ def forward(self, data):
399
+ return data
400
+
401
+ @register_noise(name='gaussian')
402
+ class GaussianNoise(Noise):
403
+ def __init__(self, scale):
404
+ self.scale = scale
405
+
406
+ def forward(self, data):
407
+ return data + torch.randn_like(data, device=data.device) * self.scale
408
+
409
+
410
+ @register_noise(name='poisson')
411
+ class PoissonNoise(Noise):
412
+ def __init__(self, scale):
413
+ self.scale = scale
414
+
415
+ def forward(self, data):
416
+ '''
417
+ Follows skimage.util.random_noise.
418
+ '''
419
+
420
+ # version 3 (stack-overflow)
421
+ import numpy as np
422
+ data = (data + 1.0) / 2.0
423
+ data = data.clamp(0, 1)
424
+ device = data.device
425
+ data = data.detach().cpu()
426
+ data = torch.from_numpy(np.random.poisson(data * 255.0 * self.scale) / 255.0 / self.scale)
427
+ data = data * 2.0 - 1.0
428
+ data = data.clamp(-1, 1)
429
+ return data.to(device)
src/flair/functions/nonuniform/kernels/000001.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:538b2e852fdfd3966628fbf22b53ed75343c43084608aa12c05c7dbbd0db6728
3
+ size 109028
src/flair/functions/svd_ddnm.py ADDED
@@ -0,0 +1,206 @@
1
+ import torch
2
+ from tqdm import tqdm
3
+ import torchvision.utils as tvu
4
+ import torchvision
5
+ import os
6
+
7
+ class_num = 951
8
+
9
+
10
+ def compute_alpha(beta, t):
11
+ beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
12
+ a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
13
+ return a
14
+
15
+ def inverse_data_transform(x):
16
+ x = (x + 1.0) / 2.0
17
+ return torch.clamp(x, 0.0, 1.0)
18
+
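compute_alpha returns the cumulative product alpha_bar_t = prod_{s<=t}(1 - beta_s), with a zero prepended to beta so that t = -1 maps to alpha_bar = 1; the samplers below use it in the standard x0-prediction x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t).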
19
+ def ddnm_diffusion(x, model, b, eta, A_funcs, y, cls_fn=None, classes=None, config=None):
20
+ with torch.no_grad():
21
+
22
+ # setup iteration variables
23
+ skip = config.diffusion.num_diffusion_timesteps//config.time_travel.T_sampling
24
+ n = x.size(0)
25
+ x0_preds = []
26
+ xs = [x]
27
+
28
+ # generate time schedule
29
+ times = get_schedule_jump(config.time_travel.T_sampling,
30
+ config.time_travel.travel_length,
31
+ config.time_travel.travel_repeat,
32
+ )
33
+ time_pairs = list(zip(times[:-1], times[1:]))
34
+
35
+ # reverse diffusion sampling
36
+ for i, j in tqdm(time_pairs):
37
+ i, j = i*skip, j*skip
38
+ if j<0: j=-1
39
+
40
+ if j < i: # normal sampling
41
+ t = (torch.ones(n) * i).to(x.device)
42
+ next_t = (torch.ones(n) * j).to(x.device)
43
+ at = compute_alpha(b, t.long())
44
+ at_next = compute_alpha(b, next_t.long())
45
+ xt = xs[-1].to('cuda')
46
+ if cls_fn == None:
47
+ et = model(xt, t)
48
+ else:
49
+ classes = torch.ones(xt.size(0), dtype=torch.long, device=torch.device("cuda"))*class_num
50
+ et = model(xt, t, classes)
51
+ et = et[:, :3]
52
+ et = et - (1 - at).sqrt()[0, 0, 0, 0] * cls_fn(x, t, classes)
53
+
54
+ if et.size(1) == 6:
55
+ et = et[:, :3]
56
+
57
+ x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
58
+
59
+ x0_t_hat = x0_t - A_funcs.A_pinv(
60
+ A_funcs.A(x0_t.reshape(x0_t.size(0), -1)) - y.reshape(y.size(0), -1)
61
+ ).reshape(*x0_t.size())
62
+
63
+ c1 = (1 - at_next).sqrt() * eta
64
+ c2 = (1 - at_next).sqrt() * ((1 - eta ** 2) ** 0.5)
65
+ xt_next = at_next.sqrt() * x0_t_hat + c1 * torch.randn_like(x0_t) + c2 * et
66
+
67
+ x0_preds.append(x0_t.to('cpu'))
68
+ xs.append(xt_next.to('cpu'))
69
+ else: # time-travel back
70
+ next_t = (torch.ones(n) * j).to(x.device)
71
+ at_next = compute_alpha(b, next_t.long())
72
+ x0_t = x0_preds[-1].to('cuda')
73
+
74
+ xt_next = at_next.sqrt() * x0_t + torch.randn_like(x0_t) * (1 - at_next).sqrt()
75
+
76
+ xs.append(xt_next.to('cpu'))
77
+
78
+ return [xs[-1]], [x0_preds[-1]]
79
+
80
+ def ddnm_plus_diffusion(x, model, b, eta, A_funcs, y, sigma_y, cls_fn=None, classes=None, config=None):
81
+ with torch.no_grad():
82
+
83
+ # setup iteration variables
84
+ skip = config.diffusion.num_diffusion_timesteps//config.time_travel.T_sampling
85
+ n = x.size(0)
86
+ x0_preds = []
87
+ xs = [x]
88
+
89
+ # generate time schedule
90
+ times = get_schedule_jump(config.time_travel.T_sampling,
91
+ config.time_travel.travel_length,
92
+ config.time_travel.travel_repeat,
93
+ )
94
+ time_pairs = list(zip(times[:-1], times[1:]))
95
+
96
+ # reverse diffusion sampling
97
+ for i, j in tqdm(time_pairs):
98
+ i, j = i*skip, j*skip
99
+ if j<0: j=-1
100
+
101
+ if j < i: # normal sampling
102
+ t = (torch.ones(n) * i).to(x.device)
103
+ next_t = (torch.ones(n) * j).to(x.device)
104
+ at = compute_alpha(b, t.long())
105
+ at_next = compute_alpha(b, next_t.long())
106
+ xt = xs[-1].to('cuda')
107
+ if cls_fn == None:
108
+ et = model(xt, t)
109
+ else:
110
+ classes = torch.ones(xt.size(0), dtype=torch.long, device=torch.device("cuda"))*class_num
111
+ et = model(xt, t, classes)
112
+ et = et[:, :3]
113
+ et = et - (1 - at).sqrt()[0, 0, 0, 0] * cls_fn(x, t, classes)
114
+
115
+ if et.size(1) == 6:
116
+ et = et[:, :3]
117
+
118
+ # Eq. 12
119
+ x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
120
+
121
+ sigma_t = (1 - at_next).sqrt()[0, 0, 0, 0]
122
+
123
+ # Eq. 17
124
+ x0_t_hat = x0_t - A_funcs.Lambda(A_funcs.A_pinv(
125
+ A_funcs.A(x0_t.reshape(x0_t.size(0), -1)) - y.reshape(y.size(0), -1)
126
+ ).reshape(x0_t.size(0), -1), at_next.sqrt()[0, 0, 0, 0], sigma_y, sigma_t, eta).reshape(*x0_t.size())
127
+
128
+ # Eq. 51
129
+ xt_next = at_next.sqrt() * x0_t_hat + A_funcs.Lambda_noise(
130
+ torch.randn_like(x0_t).reshape(x0_t.size(0), -1),
131
+ at_next.sqrt()[0, 0, 0, 0], sigma_y, sigma_t, eta, et.reshape(et.size(0), -1)).reshape(*x0_t.size())
132
+
133
+ x0_preds.append(x0_t.to('cpu'))
134
+ xs.append(xt_next.to('cpu'))
135
+ else: # time-travel back
136
+ next_t = (torch.ones(n) * j).to(x.device)
137
+ at_next = compute_alpha(b, next_t.long())
138
+ x0_t = x0_preds[-1].to('cuda')
139
+
140
+ xt_next = at_next.sqrt() * x0_t + torch.randn_like(x0_t) * (1 - at_next).sqrt()
141
+
142
+ xs.append(xt_next.to('cpu'))
143
+
144
+ # #ablation
145
+ # if i%50==0:
146
+ # os.makedirs('/userhome/wyh/ddnm/debug/x0t', exist_ok=True)
147
+ # tvu.save_image(
148
+ # inverse_data_transform(x0_t[0]),
149
+ # os.path.join('/userhome/wyh/ddnm/debug/x0t', f"x0_t_{i}.png")
150
+ # )
151
+
152
+ # os.makedirs('/userhome/wyh/ddnm/debug/x0_t_hat', exist_ok=True)
153
+ # tvu.save_image(
154
+ # inverse_data_transform(x0_t_hat[0]),
155
+ # os.path.join('/userhome/wyh/ddnm/debug/x0_t_hat', f"x0_t_hat_{i}.png")
156
+ # )
157
+
158
+ # os.makedirs('/userhome/wyh/ddnm/debug/xt_next', exist_ok=True)
159
+ # tvu.save_image(
160
+ # inverse_data_transform(xt_next[0]),
161
+ # os.path.join('/userhome/wyh/ddnm/debug/xt_next', f"xt_next_{i}.png")
162
+ # )
163
+
164
+ return [xs[-1]], [x0_preds[-1]]
165
+
166
+ # from RePaint
167
+ def get_schedule_jump(T_sampling, travel_length, travel_repeat):
168
+
169
+ jumps = {}
170
+ for j in range(0, T_sampling - travel_length, travel_length):
171
+ jumps[j] = travel_repeat - 1
172
+
173
+ t = T_sampling
174
+ ts = []
175
+
176
+ while t >= 1:
177
+ t = t-1
178
+ ts.append(t)
179
+
180
+ if jumps.get(t, 0) > 0:
181
+ jumps[t] = jumps[t] - 1
182
+ for _ in range(travel_length):
183
+ t = t + 1
184
+ ts.append(t)
185
+
186
+ ts.append(-1)
187
+
188
+ _check_times(ts, -1, T_sampling)
189
+
190
+ return ts
191
+
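A small worked example of the schedule, with values chosen only for illustration:

    get_schedule_jump(T_sampling=5, travel_length=2, travel_repeat=2)
    # -> [4, 3, 2, 3, 4, 3, 2, 1, 0, 1, 2, 1, 0, -1]

Every registered jump point triggers travel_repeat - 1 re-noising excursions of travel_length steps, and the schedule always terminates at -1, which _check_times below verifies.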
192
+ def _check_times(times, t_0, T_sampling):
193
+ # Check end
194
+ assert times[0] > times[1], (times[0], times[1])
195
+
196
+ # Check beginning
197
+ assert times[-1] == -1, times[-1]
198
+
199
+ # Steplength = 1
200
+ for t_last, t_cur in zip(times[:-1], times[1:]):
201
+ assert abs(t_last - t_cur) == 1, (t_last, t_cur)
202
+
203
+ # Value range
204
+ for t in times:
205
+ assert t >= t_0, (t, t_0)
206
+ assert t <= T_sampling, (t, T_sampling)
src/flair/functions/svd_operators.py ADDED
@@ -0,0 +1,1308 @@
1
+ import torch
2
+
3
+
4
+ class A_functions:
5
+ """
6
+ A class replacing the SVD of a matrix A, perhaps efficiently.
7
+ All input vectors are of shape (Batch, ...).
8
+ All output vectors are of shape (Batch, DataDimension).
9
+ """
10
+
11
+ def V(self, vec):
12
+ """
13
+ Multiplies the input vector by V
14
+ """
15
+ raise NotImplementedError()
16
+
17
+ def Vt(self, vec):
18
+ """
19
+ Multiplies the input vector by V transposed
20
+ """
21
+ raise NotImplementedError()
22
+
23
+ def U(self, vec):
24
+ """
25
+ Multiplies the input vector by U
26
+ """
27
+ raise NotImplementedError()
28
+
29
+ def Ut(self, vec):
30
+ """
31
+ Multiplies the input vector by U transposed
32
+ """
33
+ raise NotImplementedError()
34
+
35
+ def singulars(self):
36
+ """
37
+ Returns a vector containing the singular values. The shape of the vector should be the same as the smaller dimension (like U)
38
+ """
39
+ raise NotImplementedError()
40
+
41
+ def add_zeros(self, vec):
42
+ """
43
+ Adds trailing zeros to turn a vector from the small dimension (U) to the big dimension (V)
44
+ """
45
+ raise NotImplementedError()
46
+
47
+ def A(self, vec):
48
+ """
49
+ Multiplies the input vector by A
50
+ """
51
+ temp = self.Vt(vec)
52
+ singulars = self.singulars()
53
+ return self.U(singulars * temp[:, :singulars.shape[0]])
54
+
55
+ def At(self, vec):
56
+ """
57
+ Multiplies the input vector by A transposed
58
+ """
59
+ temp = self.Ut(vec)
60
+ singulars = self.singulars()
61
+ return self.V(self.add_zeros(singulars * temp[:, :singulars.shape[0]]))
62
+
63
+ def A_pinv(self, vec):
64
+ """
65
+ Multiplies the input vector by the pseudo inverse of A
66
+ """
67
+ temp = self.Ut(vec)
68
+ singulars = self.singulars()
69
+
70
+ factors = 1. / singulars
71
+ factors[singulars == 0] = 0.
72
+
73
+ # temp[:, :singulars.shape[0]] = temp[:, :singulars.shape[0]] / singulars
74
+ temp[:, :singulars.shape[0]] = temp[:, :singulars.shape[0]] * factors
75
+ return self.V(self.add_zeros(temp))
76
+
77
+ def A_pinv_eta(self, vec, eta):
78
+ """
79
+ Multiplies the input vector by the pseudo inverse of A with factor eta
80
+ """
81
+ temp = self.Ut(vec)
82
+ singulars = self.singulars()
83
+ factors = singulars / (singulars*singulars+eta)
84
+ # print(temp.size(), factors.size(), singulars.size())
85
+ temp[:, :singulars.shape[0]] = temp[:, :singulars.shape[0]] * factors
86
+ return self.V(self.add_zeros(temp))
87
+
88
+ def Lambda(self, vec, a, sigma_y, sigma_t, eta):
89
+ raise NotImplementedError()
90
+
91
+ def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
92
+ raise NotImplementedError()
93
+
94
+
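In terms of the implicit factorization A = U Sigma V^T, the helpers above compose as A(x) = U(Sigma * Vt(x)) and A_pinv(y) = V(add_zeros(Sigma^+ * Ut(y))), where Sigma^+ zeroes the reciprocals of vanishing singular values; subclasses only need to supply U, Ut, V, Vt, singulars and add_zeros. A minimal sanity-check sketch using the Denoising operator defined below (shapes and device are assumptions):

    A_funcs = Denoising(channels=3, img_dim=256, device='cuda')
    x_flat = x.reshape(x.size(0), -1)          # x assumed to be a (B, 3, 256, 256) tensor on 'cuda'
    assert torch.allclose(A_funcs.A_pinv(A_funcs.A(x_flat)), x_flat)  # A is the identity here, so pinv(A(x)) == x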
95
+ # block-wise CS
96
+ class CS(A_functions):
97
+ def __init__(self, channels, img_dim, ratio, device): #ratio = 2 or 4
98
+ self.img_dim = img_dim
99
+ self.channels = channels
100
+ self.y_dim = img_dim // 32
101
+ self.ratio = 32
102
+ A = torch.randn(32**2, 32**2).to(device)
103
+ _, _, self.V_small = torch.svd(A, some=False)
104
+ self.Vt_small = self.V_small.transpose(0, 1)
105
+ self.singulars_small = torch.ones(int(32 * 32 * ratio), device=device)
106
+ self.cs_size = self.singulars_small.size(0)
107
+
108
+ def V(self, vec):
109
+ #reorder the vector back into patches (because singulars are ordered descendingly)
110
+
111
+ temp = vec.clone().reshape(vec.shape[0], -1)
112
+ patches = torch.zeros(vec.size(0), self.channels * self.y_dim ** 2, self.ratio ** 2, device=vec.device)
113
+ patches[:, :, :self.cs_size] = temp[:, :self.channels * self.y_dim ** 2 * self.cs_size].contiguous().reshape(
114
+ vec.size(0), -1, self.cs_size)
115
+ patches[:, :, self.cs_size:] = temp[:, self.channels * self.y_dim ** 2 * self.cs_size:].contiguous().reshape(
116
+ vec.size(0), self.channels * self.y_dim ** 2, -1)
117
+
118
+ #multiply each patch by the small V
119
+ patches = torch.matmul(self.V_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2)
120
+ #repatch the patches into an image
121
+ patches_orig = patches.reshape(vec.shape[0], self.channels, self.y_dim, self.y_dim, self.ratio, self.ratio)
122
+ recon = patches_orig.permute(0, 1, 2, 4, 3, 5).contiguous()
123
+ recon = recon.reshape(vec.shape[0], self.channels * self.img_dim ** 2)
124
+ return recon
125
+
126
+ def Vt(self, vec):
127
+ #extract flattened patches
128
+ patches = vec.clone().reshape(vec.shape[0], self.channels, self.img_dim, self.img_dim)
129
+ patches = patches.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio)
130
+ patches = patches.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio**2)
131
+ #multiply each by the small V transposed
132
+ patches = torch.matmul(self.Vt_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2)
133
+ #reorder the vector to have the first entry first (because singulars are ordered descendingly)
134
+ recon = torch.zeros(vec.shape[0], self.channels * self.img_dim**2, device=vec.device)
135
+ recon[:, :self.channels * self.y_dim ** 2 * self.cs_size] = patches[:, :, :, :self.cs_size].contiguous().reshape(
136
+ vec.shape[0], -1)
137
+ recon[:, self.channels * self.y_dim ** 2 * self.cs_size:] = patches[:, :, :, self.cs_size:].contiguous().reshape(
138
+ vec.shape[0], -1)
139
+ return recon
140
+
141
+ def U(self, vec):
142
+ return vec.clone().reshape(vec.shape[0], -1)
143
+
144
+ def Ut(self, vec): #U is 1x1, so U^T = U
145
+ return vec.clone().reshape(vec.shape[0], -1)
146
+
147
+ def singulars(self):
148
+ return self.singulars_small.repeat(self.channels * self.y_dim**2)
149
+
150
+ def add_zeros(self, vec):
151
+ reshaped = vec.clone().reshape(vec.shape[0], -1)
152
+ temp = torch.zeros((vec.shape[0], self.channels * self.img_dim**2), device=vec.device)
153
+ temp[:, :reshaped.shape[1]] = reshaped
154
+ return temp
155
+
156
+
157
+ def color2gray(x):
158
+ x = x[:, 0:1, :, :] * 0.3333 + x[:, 1:2, :, :] * 0.3334 + x[:, 2:, :, :] * 0.3333
159
+ return x
160
+
161
+
162
+ def gray2color(x):
163
+ base = 0.3333 ** 2 + 0.3334 ** 2 + 0.3333 ** 2
164
+ return torch.stack((x * 0.3333 / base, x * 0.3334 / base, x * 0.3333 / base), 1)
165
+
166
+
167
+ #a memory inefficient implementation for any general degradation A
168
+ class GeneralA(A_functions):
169
+ def mat_by_vec(self, M, v):
170
+ vshape = v.shape[1]
171
+ if len(v.shape) > 2: vshape = vshape * v.shape[2]
172
+ if len(v.shape) > 3: vshape = vshape * v.shape[3]
173
+ return torch.matmul(M, v.view(v.shape[0], vshape,
174
+ 1)).view(v.shape[0], M.shape[0])
175
+
176
+ def __init__(self, A):
177
+ self._U, self._singulars, self._V = torch.svd(A, some=False)
178
+ self._Vt = self._V.transpose(0, 1)
179
+ self._Ut = self._U.transpose(0, 1)
180
+
181
+ ZERO = 1e-3
182
+ self._singulars[self._singulars < ZERO] = 0
183
+ print(len([x.item() for x in self._singulars if x == 0]))
184
+
185
+ def V(self, vec):
186
+ return self.mat_by_vec(self._V, vec.clone())
187
+
188
+ def Vt(self, vec):
189
+ return self.mat_by_vec(self._Vt, vec.clone())
190
+
191
+ def U(self, vec):
192
+ return self.mat_by_vec(self._U, vec.clone())
193
+
194
+ def Ut(self, vec):
195
+ return self.mat_by_vec(self._Ut, vec.clone())
196
+
197
+ def singulars(self):
198
+ return self._singulars
199
+
200
+ def add_zeros(self, vec):
201
+ out = torch.zeros(vec.shape[0], self._V.shape[0], device=vec.device)
202
+ out[:, :self._U.shape[0]] = vec.clone().reshape(vec.shape[0], -1)
203
+ return out
204
+
205
+ #Walsh-Hadamard Compressive Sensing
206
+ class WalshHadamardCS(A_functions):
207
+ def fwht(self, vec): #the Fast Walsh Hadamard Transform is the same as its inverse
208
+ a = vec.reshape(vec.shape[0], self.channels, self.img_dim**2)
209
+ h = 1
210
+ while h < self.img_dim**2:
211
+ a = a.reshape(vec.shape[0], self.channels, -1, h * 2)
212
+ b = a.clone()
213
+ a[:, :, :, :h] = b[:, :, :, :h] + b[:, :, :, h:2*h]
214
+ a[:, :, :, h:2*h] = b[:, :, :, :h] - b[:, :, :, h:2*h]
215
+ h *= 2
216
+ a = a.reshape(vec.shape[0], self.channels, self.img_dim**2) / self.img_dim
217
+ return a
218
+
219
+ def __init__(self, channels, img_dim, ratio, perm, device):
220
+ self.channels = channels
221
+ self.img_dim = img_dim
222
+ self.ratio = ratio
223
+ self.perm = perm
224
+ self._singulars = torch.ones(channels * img_dim**2 // ratio, device=device)
225
+
226
+ def V(self, vec):
227
+ temp = torch.zeros(vec.shape[0], self.channels, self.img_dim**2, device=vec.device)
228
+ temp[:, :, self.perm] = vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1)
229
+ return self.fwht(temp).reshape(vec.shape[0], -1)
230
+
231
+ def Vt(self, vec):
232
+ return self.fwht(vec.clone())[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1)
233
+
234
+ def U(self, vec):
235
+ return vec.clone().reshape(vec.shape[0], -1)
236
+
237
+ def Ut(self, vec):
238
+ return vec.clone().reshape(vec.shape[0], -1)
239
+
240
+ def singulars(self):
241
+ return self._singulars
242
+
243
+ def add_zeros(self, vec):
244
+ out = torch.zeros(vec.shape[0], self.channels * self.img_dim**2, device=vec.device)
245
+ out[:, :self.channels * self.img_dim**2 // self.ratio] = vec.clone().reshape(vec.shape[0], -1)
246
+ return out
247
+
248
+ def Lambda(self, vec, a, sigma_y, sigma_t, eta):
249
+ temp_vec = self.fwht(vec.clone())[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1)
250
+
251
+ singulars = self._singulars
252
+ lambda_t = torch.ones(self.channels * self.img_dim ** 2, device=vec.device)
253
+ temp = torch.zeros(self.channels * self.img_dim ** 2, device=vec.device)
254
+ temp[:singulars.size(0)] = singulars
255
+ singulars = temp
256
+ inverse_singulars = 1. / singulars
257
+ inverse_singulars[singulars == 0] = 0.
258
+
259
+ if a != 0 and sigma_y != 0:
260
+ change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
261
+ lambda_t = lambda_t * (-change_index + 1.0) + change_index * (
262
+ singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y)
263
+
264
+ lambda_t = lambda_t.reshape(1, -1)
265
+ temp_vec = temp_vec * lambda_t
266
+
267
+ temp_out = torch.zeros(vec.shape[0], self.channels, self.img_dim ** 2, device=vec.device)
268
+ temp_out[:, :, self.perm] = temp_vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1)
269
+ return self.fwht(temp_out).reshape(vec.shape[0], -1)
270
+
271
+ def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
272
+ temp_vec = vec.clone().reshape(
273
+ vec.shape[0], self.channels, self.img_dim ** 2)[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1)
274
+ temp_eps = epsilon.clone().reshape(
275
+ vec.shape[0], self.channels, self.img_dim ** 2)[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1)
276
+
277
+ d1_t = torch.ones(self.channels * self.img_dim ** 2, device=vec.device) * sigma_t * eta
278
+ d2_t = torch.ones(self.channels * self.img_dim ** 2, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5
279
+
280
+ singulars = self._singulars
281
+ temp = torch.zeros(self.channels * self.img_dim ** 2, device=vec.device)
282
+ temp[:singulars.size(0)] = singulars
283
+ singulars = temp
284
+ inverse_singulars = 1. / singulars
285
+ inverse_singulars[singulars == 0] = 0.
286
+
287
+ if a != 0 and sigma_y != 0:
288
+ change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
289
+ d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
290
+ d2_t = d2_t * (-change_index + 1.0)
291
+
292
+ change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0
293
+ d1_t = d1_t * (-change_index + 1.0) + torch.sqrt(
294
+ change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2))
295
+ d2_t = d2_t * (-change_index + 1.0)
296
+
297
+ change_index = (singulars == 0) * 1.0
298
+ d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
299
+ d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5
300
+
301
+ d1_t = d1_t.reshape(1, -1)
302
+ d2_t = d2_t.reshape(1, -1)
303
+
304
+ temp_vec = temp_vec * d1_t
305
+ temp_eps = temp_eps * d2_t
306
+
307
+ temp_out_vec = torch.zeros(vec.shape[0], self.channels, self.img_dim ** 2, device=vec.device)
308
+ temp_out_vec[:, :, self.perm] = temp_vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1)
309
+ temp_out_vec = self.fwht(temp_out_vec).reshape(vec.shape[0], -1)
310
+
311
+ temp_out_eps = torch.zeros(vec.shape[0], self.channels, self.img_dim ** 2, device=vec.device)
312
+ temp_out_eps[:, :, self.perm] = temp_eps.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1)
313
+ temp_out_eps = self.fwht(temp_out_eps).reshape(vec.shape[0], -1)
314
+
315
+ return temp_out_vec + temp_out_eps
316
+
317
+
318
+ #Inpainting
319
+ class Inpainting(A_functions):
320
+ def __init__(self, channels, img_dim, missing_indices, device):
321
+ self.channels = channels
322
+ self.img_dim = img_dim
323
+ self._singulars = torch.ones(channels * img_dim**2 - missing_indices.shape[0]).to(device)
324
+ self.missing_indices = missing_indices
325
+ self.kept_indices = torch.Tensor([i for i in range(channels * img_dim**2) if i not in missing_indices]).to(device).long()
326
+
327
+ def V(self, vec):
328
+ temp = vec.clone().reshape(vec.shape[0], -1)
329
+ out = torch.zeros_like(temp)
330
+ out[:, self.kept_indices] = temp[:, :self.kept_indices.shape[0]]
331
+ out[:, self.missing_indices] = temp[:, self.kept_indices.shape[0]:]
332
+ return out.reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1).reshape(vec.shape[0], -1)
333
+
334
+ def Vt(self, vec):
335
+ temp = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1).reshape(vec.shape[0], -1)
336
+ out = torch.zeros_like(temp)
337
+ out[:, :self.kept_indices.shape[0]] = temp[:, self.kept_indices]
338
+ out[:, self.kept_indices.shape[0]:] = temp[:, self.missing_indices]
339
+ return out
340
+
341
+ def U(self, vec):
342
+ return vec.clone().reshape(vec.shape[0], -1)
343
+
344
+ def Ut(self, vec):
345
+ return vec.clone().reshape(vec.shape[0], -1)
346
+
347
+ def singulars(self):
348
+ return self._singulars
349
+
350
+ def add_zeros(self, vec):
351
+ temp = torch.zeros((vec.shape[0], self.channels * self.img_dim**2), device=vec.device)
352
+ reshaped = vec.clone().reshape(vec.shape[0], -1)
353
+ temp[:, :reshaped.shape[1]] = reshaped
354
+ return temp
355
+
356
+ def Lambda(self, vec, a, sigma_y, sigma_t, eta):
357
+
358
+ temp = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1).reshape(vec.shape[0], -1)
359
+ out = torch.zeros_like(temp)
360
+ out[:, :self.kept_indices.shape[0]] = temp[:, self.kept_indices]
361
+ out[:, self.kept_indices.shape[0]:] = temp[:, self.missing_indices]
362
+
363
+ singulars = self._singulars
364
+ lambda_t = torch.ones(temp.size(1), device=vec.device)
365
+ temp_singulars = torch.zeros(temp.size(1), device=vec.device)
366
+ temp_singulars[:singulars.size(0)] = singulars
367
+ singulars = temp_singulars
368
+ inverse_singulars = 1. / singulars
369
+ inverse_singulars[singulars == 0] = 0.
370
+
371
+ if a != 0 and sigma_y != 0:
372
+ change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
373
+ lambda_t = lambda_t * (-change_index + 1.0) + change_index * (
374
+ singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y)
375
+
376
+ lambda_t = lambda_t.reshape(1, -1)
377
+ out = out * lambda_t
378
+
379
+ result = torch.zeros_like(temp)
380
+ result[:, self.kept_indices] = out[:, :self.kept_indices.shape[0]]
381
+ result[:, self.missing_indices] = out[:, self.kept_indices.shape[0]:]
382
+ return result.reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1).reshape(vec.shape[0], -1)
383
+
384
+ def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
385
+ temp_vec = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1).reshape(vec.shape[0], -1)
386
+ out_vec = torch.zeros_like(temp_vec)
387
+ out_vec[:, :self.kept_indices.shape[0]] = temp_vec[:, self.kept_indices]
388
+ out_vec[:, self.kept_indices.shape[0]:] = temp_vec[:, self.missing_indices]
389
+
390
+ temp_eps = epsilon.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1).reshape(vec.shape[0], -1)
391
+ out_eps = torch.zeros_like(temp_eps)
392
+ out_eps[:, :self.kept_indices.shape[0]] = temp_eps[:, self.kept_indices]
393
+ out_eps[:, self.kept_indices.shape[0]:] = temp_eps[:, self.missing_indices]
394
+
395
+ singulars = self._singulars
396
+ d1_t = torch.ones(temp_vec.size(1), device=vec.device) * sigma_t * eta
397
+ d2_t = torch.ones(temp_vec.size(1), device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5
398
+
399
+ temp_singulars = torch.zeros(temp_vec.size(1), device=vec.device)
400
+ temp_singulars[:singulars.size(0)] = singulars
401
+ singulars = temp_singulars
402
+ inverse_singulars = 1. / singulars
403
+ inverse_singulars[singulars == 0] = 0.
404
+
405
+ if a != 0 and sigma_y != 0:
406
+ change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
407
+ d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
408
+ d2_t = d2_t * (-change_index + 1.0)
409
+
410
+ change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0
411
+ d1_t = d1_t * (-change_index + 1.0) + torch.sqrt(
412
+ change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2))
413
+ d2_t = d2_t * (-change_index + 1.0)
414
+
415
+ change_index = (singulars == 0) * 1.0
416
+ d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
417
+ d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5
418
+
419
+ d1_t = d1_t.reshape(1, -1)
420
+ d2_t = d2_t.reshape(1, -1)
421
+ out_vec = out_vec * d1_t
422
+ out_eps = out_eps * d2_t
423
+
424
+ result_vec = torch.zeros_like(temp_vec)
425
+ result_vec[:, self.kept_indices] = out_vec[:, :self.kept_indices.shape[0]]
426
+ result_vec[:, self.missing_indices] = out_vec[:, self.kept_indices.shape[0]:]
427
+ result_vec = result_vec.reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1).reshape(vec.shape[0], -1)
428
+
429
+ result_eps = torch.zeros_like(temp_eps)
430
+ result_eps[:, self.kept_indices] = out_eps[:, :self.kept_indices.shape[0]]
431
+ result_eps[:, self.missing_indices] = out_eps[:, self.kept_indices.shape[0]:]
432
+ result_eps = result_eps.reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1).reshape(vec.shape[0], -1)
433
+
434
+ return result_vec + result_eps
435
+
436
+ #Denoising
437
+ class Denoising(A_functions):
438
+ def __init__(self, channels, img_dim, device):
439
+ self._singulars = torch.ones(channels * img_dim**2, device=device)
440
+
441
+ def V(self, vec):
442
+ return vec.clone().reshape(vec.shape[0], -1)
443
+
444
+ def Vt(self, vec):
445
+ return vec.clone().reshape(vec.shape[0], -1)
446
+
447
+ def U(self, vec):
448
+ return vec.clone().reshape(vec.shape[0], -1)
449
+
450
+ def Ut(self, vec):
451
+ return vec.clone().reshape(vec.shape[0], -1)
452
+
453
+ def singulars(self):
454
+ return self._singulars
455
+
456
+ def add_zeros(self, vec):
457
+ return vec.clone().reshape(vec.shape[0], -1)
458
+
459
+ def Lambda(self, vec, a, sigma_y, sigma_t, eta):
460
+ if sigma_t < a * sigma_y:
461
+ factor = (sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y).item()
462
+ return vec * factor
463
+ else:
464
+ return vec
465
+
466
+ def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
467
+ if sigma_t >= a * sigma_y:
468
+ factor = torch.sqrt(sigma_t ** 2 - a ** 2 * sigma_y ** 2).item()
469
+ return vec * factor
470
+ else:
471
+ return vec * sigma_t * eta
472
+
473
+ #Super Resolution
474
+ class SuperResolution(A_functions):
475
+ def __init__(self, channels, img_dim, ratio, device): #ratio = 2 or 4
476
+ assert img_dim % ratio == 0
477
+ self.img_dim = img_dim
478
+ self.channels = channels
479
+ self.y_dim = img_dim // ratio
480
+ self.ratio = ratio
481
+ A = torch.Tensor([[1 / ratio**2] * ratio**2]).to(device)
482
+ self.U_small, self.singulars_small, self.V_small = torch.svd(A, some=False)
483
+ self.Vt_small = self.V_small.transpose(0, 1)
484
+
485
+ def V(self, vec):
486
+ #reorder the vector back into patches (because singulars are ordered descendingly)
487
+ temp = vec.clone().reshape(vec.shape[0], -1)
488
+ patches = torch.zeros(vec.shape[0], self.channels, self.y_dim**2, self.ratio**2, device=vec.device)
489
+ patches[:, :, :, 0] = temp[:, :self.channels * self.y_dim**2].view(vec.shape[0], self.channels, -1)
490
+ for idx in range(self.ratio**2-1):
491
+ patches[:, :, :, idx+1] = temp[:, (self.channels*self.y_dim**2+idx)::self.ratio**2-1].view(vec.shape[0], self.channels, -1)
492
+ #multiply each patch by the small V
493
+ patches = torch.matmul(self.V_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2)
494
+ #repatch the patches into an image
495
+ patches_orig = patches.reshape(vec.shape[0], self.channels, self.y_dim, self.y_dim, self.ratio, self.ratio)
496
+ recon = patches_orig.permute(0, 1, 2, 4, 3, 5).contiguous()
497
+ recon = recon.reshape(vec.shape[0], self.channels * self.img_dim ** 2)
498
+ return recon
499
+
500
+ def Vt(self, vec):
501
+ #extract flattened patches
502
+ patches = vec.clone().reshape(vec.shape[0], self.channels, self.img_dim, self.img_dim)
503
+ patches = patches.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio)
504
+ unfold_shape = patches.shape
505
+ patches = patches.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio**2)
506
+ #multiply each by the small V transposed
507
+ patches = torch.matmul(self.Vt_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2)
508
+ #reorder the vector to have the first entry first (because singulars are ordered descendingly)
509
+ recon = torch.zeros(vec.shape[0], self.channels * self.img_dim**2, device=vec.device)
510
+ recon[:, :self.channels * self.y_dim**2] = patches[:, :, :, 0].view(vec.shape[0], self.channels * self.y_dim**2)
511
+ for idx in range(self.ratio**2-1):
512
+ recon[:, (self.channels*self.y_dim**2+idx)::self.ratio**2-1] = patches[:, :, :, idx+1].view(vec.shape[0], self.channels * self.y_dim**2)
513
+ return recon
514
+
515
+ def U(self, vec):
516
+ return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1)
517
+
518
+ def Ut(self, vec): #U is 1x1, so U^T = U
519
+ return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1)
520
+
521
+ def singulars(self):
522
+ return self.singulars_small.repeat(self.channels * self.y_dim**2)
523
+
524
+ def add_zeros(self, vec):
525
+ reshaped = vec.clone().reshape(vec.shape[0], -1)
526
+ temp = torch.zeros((vec.shape[0], reshaped.shape[1] * self.ratio**2), device=vec.device)
527
+ temp[:, :reshaped.shape[1]] = reshaped
528
+ return temp
529
+
530
+ def Lambda(self, vec, a, sigma_y, sigma_t, eta):
531
+ singulars = self.singulars_small
532
+
533
+ patches = vec.clone().reshape(vec.shape[0], self.channels, self.img_dim, self.img_dim)
534
+ patches = patches.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio)
535
+ patches = patches.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio ** 2)
536
+
537
+ patches = torch.matmul(self.Vt_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2)
538
+
539
+ lambda_t = torch.ones(self.ratio ** 2, device=vec.device)
540
+
541
+ temp = torch.zeros(self.ratio ** 2, device=vec.device)
542
+ temp[:singulars.size(0)] = singulars
543
+ singulars = temp
544
+ inverse_singulars = 1. / singulars
545
+ inverse_singulars[singulars == 0] = 0.
546
+
547
+ if a != 0 and sigma_y != 0:
548
+ change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
549
+ lambda_t = lambda_t * (-change_index + 1.0) + change_index * (singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y)
550
+
551
+ lambda_t = lambda_t.reshape(1, 1, 1, -1)
552
+ # print("lambda_t:", lambda_t)
553
+ # print("V:", self.V_small)
554
+ # print(lambda_t.size(), self.V_small.size())
555
+ # print("Sigma_t:", torch.matmul(torch.matmul(self.V_small, torch.diag(lambda_t.reshape(-1))), self.Vt_small))
556
+ patches = patches * lambda_t
557
+
558
+
559
+ patches = torch.matmul(self.V_small, patches.reshape(-1, self.ratio**2, 1))
560
+
561
+ patches = patches.reshape(vec.shape[0], self.channels, self.y_dim, self.y_dim, self.ratio, self.ratio)
562
+ patches = patches.permute(0, 1, 2, 4, 3, 5).contiguous()
563
+ patches = patches.reshape(vec.shape[0], self.channels * self.img_dim ** 2)
564
+
565
+ return patches
566
+
567
+ def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
568
+ singulars = self.singulars_small
569
+
570
+ patches_vec = vec.clone().reshape(vec.shape[0], self.channels, self.img_dim, self.img_dim)
571
+ patches_vec = patches_vec.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio)
572
+ patches_vec = patches_vec.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio ** 2)
573
+
574
+ patches_eps = epsilon.clone().reshape(vec.shape[0], self.channels, self.img_dim, self.img_dim)
575
+ patches_eps = patches_eps.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio)
576
+ patches_eps = patches_eps.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio ** 2)
577
+
578
+ d1_t = torch.ones(self.ratio ** 2, device=vec.device) * sigma_t * eta
579
+ d2_t = torch.ones(self.ratio ** 2, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5
580
+
581
+ temp = torch.zeros(self.ratio ** 2, device=vec.device)
582
+ temp[:singulars.size(0)] = singulars
583
+ singulars = temp
584
+ inverse_singulars = 1. / singulars
585
+ inverse_singulars[singulars == 0] = 0.
586
+
587
+ if a != 0 and sigma_y != 0:
588
+
589
+ change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
590
+ d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
591
+ d2_t = d2_t * (-change_index + 1.0)
592
+
593
+ change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0
594
+ d1_t = d1_t * (-change_index + 1.0) + torch.sqrt(change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2))
595
+ d2_t = d2_t * (-change_index + 1.0)
596
+
597
+ change_index = (singulars == 0) * 1.0
598
+ d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
599
+ d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5
600
+
601
+ d1_t = d1_t.reshape(1, 1, 1, -1)
602
+ d2_t = d2_t.reshape(1, 1, 1, -1)
603
+ patches_vec = patches_vec * d1_t
604
+ patches_eps = patches_eps * d2_t
605
+
606
+ patches_vec = torch.matmul(self.V_small, patches_vec.reshape(-1, self.ratio**2, 1))
607
+
608
+ patches_vec = patches_vec.reshape(vec.shape[0], self.channels, self.y_dim, self.y_dim, self.ratio, self.ratio)
609
+ patches_vec = patches_vec.permute(0, 1, 2, 4, 3, 5).contiguous()
610
+ patches_vec = patches_vec.reshape(vec.shape[0], self.channels * self.img_dim ** 2)
611
+
612
+ patches_eps = torch.matmul(self.V_small, patches_eps.reshape(-1, self.ratio**2, 1))
613
+
614
+ patches_eps = patches_eps.reshape(vec.shape[0], self.channels, self.y_dim, self.y_dim, self.ratio, self.ratio)
615
+ patches_eps = patches_eps.permute(0, 1, 2, 4, 3, 5).contiguous()
616
+ patches_eps = patches_eps.reshape(vec.shape[0], self.channels * self.img_dim ** 2)
617
+
618
+ return patches_vec + patches_eps
619
+
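For the patch-averaging operator above, A restricted to one patch is the 1 x ratio^2 row (1/ratio^2, ..., 1/ratio^2), so its single singular value is sqrt(ratio^2 * ratio^-4) = 1/ratio (0.25 for ratio = 4); singulars() repeats that value once per low-resolution pixel and channel, which A_pinv inverts when projecting measurements back to image space.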
620
+ class SuperResolutionGeneral(SuperResolution):
621
+ def __init__(self, channels, imgH, imgW, ratio, device): #ratio = 2 or 4
622
+ assert imgH % ratio == 0 and imgW % ratio == 0
623
+ self.imgH = imgH
624
+ self.imgW = imgW
625
+ self.channels = channels
626
+ self.yH = imgH // ratio
627
+ self.yW = imgW // ratio
628
+ self.ratio = ratio
629
+ A = torch.Tensor([[1 / ratio**2] * ratio**2]).to(device)
630
+ self.U_small, self.singulars_small, self.V_small = torch.svd(A, some=False)
631
+ self.Vt_small = self.V_small.transpose(0, 1)
632
+
633
+ def V(self, vec):
634
+ #reorder the vector back into patches (because singulars are ordered descendingly)
635
+ temp = vec.clone().reshape(vec.shape[0], -1)
636
+ patches = torch.zeros(vec.shape[0], self.channels, self.yH*self.yW, self.ratio**2, device=vec.device)
637
+ patches[:, :, :, 0] = temp[:, :self.channels * self.yH*self.yW].view(vec.shape[0], self.channels, -1)
638
+ for idx in range(self.ratio**2-1):
639
+ patches[:, :, :, idx+1] = temp[:, (self.channels*self.yH*self.yW+idx)::self.ratio**2-1].view(vec.shape[0], self.channels, -1)
640
+ #multiply each patch by the small V
641
+ patches = torch.matmul(self.V_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2)
642
+ #repatch the patches into an image
643
+ patches_orig = patches.reshape(vec.shape[0], self.channels, self.yH, self.yW, self.ratio, self.ratio)
644
+ recon = patches_orig.permute(0, 1, 2, 4, 3, 5).contiguous()
645
+ recon = recon.reshape(vec.shape[0], self.channels * self.imgH * self.imgW)
646
+ return recon
647
+
648
+ def Vt(self, vec):
649
+ #extract flattened patches
650
+ patches = vec.clone().reshape(vec.shape[0], self.channels, self.imgH, self.imgW)
651
+ patches = patches.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio)
652
+ unfold_shape = patches.shape
653
+ patches = patches.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio**2)
654
+ #multiply each by the small V transposed
655
+ patches = torch.matmul(self.Vt_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2)
656
+ #reorder the vector to have the first entry first (because singulars are ordered descendingly)
657
+ recon = torch.zeros(vec.shape[0], self.channels * self.imgH*self.imgW, device=vec.device)
658
+ recon[:, :self.channels * self.yH*self.yW] = patches[:, :, :, 0].view(vec.shape[0], self.channels * self.yH*self.yW)
659
+ for idx in range(self.ratio**2-1):
660
+ recon[:, (self.channels*self.yH*self.yW+idx)::self.ratio**2-1] = patches[:, :, :, idx+1].view(vec.shape[0], self.channels * self.yH*self.yW)
661
+ return recon
662
+
663
+ def U(self, vec):
664
+ return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1)
665
+
666
+ def Ut(self, vec): #U is 1x1, so U^T = U
667
+ return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1)
668
+
669
+ def singulars(self):
670
+ return self.singulars_small.repeat(self.channels * self.yH*self.yW)
671
+
672
+ def add_zeros(self, vec):
673
+ reshaped = vec.clone().reshape(vec.shape[0], -1)
674
+ temp = torch.zeros((vec.shape[0], reshaped.shape[1] * self.ratio**2), device=vec.device)
675
+ temp[:, :reshaped.shape[1]] = reshaped
676
+ return temp
677
+
678
+ def Lambda(self, vec, a, sigma_y, sigma_t, eta):
679
+ singulars = self.singulars_small
680
+
681
+ patches = vec.clone().reshape(vec.shape[0], self.channels, self.imgH, self.imgW)
682
+ patches = patches.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio)
683
+ patches = patches.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio ** 2)
684
+
685
+ patches = torch.matmul(self.Vt_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2)
686
+
687
+ lambda_t = torch.ones(self.ratio ** 2, device=vec.device)
688
+
689
+ temp = torch.zeros(self.ratio ** 2, device=vec.device)
690
+ temp[:singulars.size(0)] = singulars
691
+ singulars = temp
692
+ inverse_singulars = 1. / singulars
693
+ inverse_singulars[singulars == 0] = 0.
694
+
695
+ if a != 0 and sigma_y != 0:
696
+ change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
697
+ lambda_t = lambda_t * (-change_index + 1.0) + change_index * (singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y)
698
+
699
+ lambda_t = lambda_t.reshape(1, 1, 1, -1)
700
+ # print("lambda_t:", lambda_t)
701
+ # print("V:", self.V_small)
702
+ # print(lambda_t.size(), self.V_small.size())
703
+ # print("Sigma_t:", torch.matmul(torch.matmul(self.V_small, torch.diag(lambda_t.reshape(-1))), self.Vt_small))
704
+ patches = patches * lambda_t
705
+
706
+
707
+ patches = torch.matmul(self.V_small, patches.reshape(-1, self.ratio**2, 1))
708
+
709
+ patches = patches.reshape(vec.shape[0], self.channels, self.yH, self.yW, self.ratio, self.ratio)
710
+ patches = patches.permute(0, 1, 2, 4, 3, 5).contiguous()
711
+ patches = patches.reshape(vec.shape[0], self.channels * self.imgH * self.imgW)
712
+
713
+ return patches
714
+
715
+ def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
716
+ singulars = self.singulars_small
717
+
718
+ patches_vec = vec.clone().reshape(vec.shape[0], self.channels, self.imgH, self.imgW)
719
+ patches_vec = patches_vec.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio)
720
+ patches_vec = patches_vec.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio ** 2)
721
+
722
+ patches_eps = epsilon.clone().reshape(vec.shape[0], self.channels, self.imgH, self.imgW)
723
+ patches_eps = patches_eps.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio)
724
+ patches_eps = patches_eps.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio ** 2)
725
+
726
+ d1_t = torch.ones(self.ratio ** 2, device=vec.device) * sigma_t * eta
727
+ d2_t = torch.ones(self.ratio ** 2, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5
728
+
729
+ temp = torch.zeros(self.ratio ** 2, device=vec.device)
730
+ temp[:singulars.size(0)] = singulars
731
+ singulars = temp
732
+ inverse_singulars = 1. / singulars
733
+ inverse_singulars[singulars == 0] = 0.
734
+
735
+ if a != 0 and sigma_y != 0:
736
+
737
+ change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
738
+ d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
739
+ d2_t = d2_t * (-change_index + 1.0)
740
+
741
+ change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0
742
+ d1_t = d1_t * (-change_index + 1.0) + torch.sqrt(change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2))
743
+ d2_t = d2_t * (-change_index + 1.0)
744
+
745
+ change_index = (singulars == 0) * 1.0
746
+ d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
747
+ d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5
748
+
749
+ d1_t = d1_t.reshape(1, 1, 1, -1)
750
+ d2_t = d2_t.reshape(1, 1, 1, -1)
751
+ patches_vec = patches_vec * d1_t
752
+ patches_eps = patches_eps * d2_t
753
+
754
+ patches_vec = torch.matmul(self.V_small, patches_vec.reshape(-1, self.ratio**2, 1))
755
+
756
+ patches_vec = patches_vec.reshape(vec.shape[0], self.channels, self.yH, self.yW, self.ratio, self.ratio)
757
+ patches_vec = patches_vec.permute(0, 1, 2, 4, 3, 5).contiguous()
758
+ patches_vec = patches_vec.reshape(vec.shape[0], self.channels * self.imgH * self.imgW)
759
+
760
+ patches_eps = torch.matmul(self.V_small, patches_eps.reshape(-1, self.ratio**2, 1))
761
+
762
+ patches_eps = patches_eps.reshape(vec.shape[0], self.channels, self.yH, self.yW, self.ratio, self.ratio)
763
+ patches_eps = patches_eps.permute(0, 1, 2, 4, 3, 5).contiguous()
764
+ patches_eps = patches_eps.reshape(vec.shape[0], self.channels * self.imgH * self.imgW)
765
+
766
+ return patches_vec + patches_eps
767
+
768
+ #Colorization
769
+ class Colorization(A_functions):
770
+ def __init__(self, img_dim, device):
771
+ self.channels = 3
772
+ self.img_dim = img_dim
773
+ #Do the SVD for the per-pixel matrix
774
+ A = torch.Tensor([[0.3333, 0.3334, 0.3333]]).to(device)
775
+ self.U_small, self.singulars_small, self.V_small = torch.svd(A, some=False)
776
+ self.Vt_small = self.V_small.transpose(0, 1)
777
+
778
+ def V(self, vec):
779
+ #get the needles
780
+ needles = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1) #shape: B, WA, C'
781
+ #multiply each needle by the small V
782
+ needles = torch.matmul(self.V_small, needles.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1, self.channels) #shape: B, WA, C
783
+ #permute back to vector representation
784
+ recon = needles.permute(0, 2, 1) #shape: B, C, WA
785
+ return recon.reshape(vec.shape[0], -1)
786
+
787
+ def Vt(self, vec):
788
+ #get the needles
789
+ needles = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1) #shape: B, WA, C
790
+ #multiply each needle by the small V transposed
791
+ needles = torch.matmul(self.Vt_small, needles.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1, self.channels) #shape: B, WA, C'
792
+ #reorder the vector so that the first entry of each needle is at the top
793
+ recon = needles.permute(0, 2, 1).reshape(vec.shape[0], -1)
794
+ return recon
795
+
796
+ def U(self, vec):
797
+ return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1)
798
+
799
+ def Ut(self, vec): #U is 1x1, so U^T = U
800
+ return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1)
801
+
802
+ def singulars(self):
803
+ return self.singulars_small.repeat(self.img_dim**2)
804
+
805
+ def add_zeros(self, vec):
806
+ reshaped = vec.clone().reshape(vec.shape[0], -1)
807
+ temp = torch.zeros((vec.shape[0], self.channels * self.img_dim**2), device=vec.device)
808
+ temp[:, :self.img_dim**2] = reshaped
809
+ return temp
810
+
811
+ def Lambda(self, vec, a, sigma_y, sigma_t, eta):
812
+ needles = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1)
813
+
814
+ needles = torch.matmul(self.Vt_small, needles.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1, self.channels)
815
+
816
+ singulars = self.singulars_small
817
+ lambda_t = torch.ones(self.channels, device=vec.device)
818
+ temp = torch.zeros(self.channels, device=vec.device)
819
+ temp[:singulars.size(0)] = singulars
820
+ singulars = temp
821
+ inverse_singulars = 1. / singulars
822
+ inverse_singulars[singulars == 0] = 0.
823
+
824
+ if a != 0 and sigma_y != 0:
825
+ change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
826
+ lambda_t = lambda_t * (-change_index + 1.0) + change_index * (
827
+ singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y)
828
+
829
+ lambda_t = lambda_t.reshape(1, 1, self.channels)
830
+ needles = needles * lambda_t
831
+
832
+ needles = torch.matmul(self.V_small, needles.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1, self.channels)
833
+
834
+ recon = needles.permute(0, 2, 1).reshape(vec.shape[0], -1)
835
+ return recon
836
+
837
+ def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
838
+ singulars = self.singulars_small
839
+
840
+ needles_vec = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1)
841
+ needles_epsilon = epsilon.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1)
842
+
843
+ d1_t = torch.ones(self.channels, device=vec.device) * sigma_t * eta
844
+ d2_t = torch.ones(self.channels, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5
845
+
846
+ temp = torch.zeros(self.channels, device=vec.device)
847
+ temp[:singulars.size(0)] = singulars
848
+ singulars = temp
849
+ inverse_singulars = 1. / singulars
850
+ inverse_singulars[singulars == 0] = 0.
851
+
852
+ if a != 0 and sigma_y != 0:
853
+ change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
854
+ d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
855
+ d2_t = d2_t * (-change_index + 1.0)
856
+
857
+ change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0
858
+ d1_t = d1_t * (-change_index + 1.0) + torch.sqrt(
859
+ change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2))
860
+ d2_t = d2_t * (-change_index + 1.0)
861
+
862
+ change_index = (singulars == 0) * 1.0
863
+ d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
864
+ d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5
865
+
866
+ d1_t = d1_t.reshape(1, 1, self.channels)
867
+ d2_t = d2_t.reshape(1, 1, self.channels)
868
+
869
+ needles_vec = needles_vec * d1_t
870
+ needles_epsilon = needles_epsilon * d2_t
871
+
872
+ needles_vec = torch.matmul(self.V_small, needles_vec.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1, self.channels)
873
+ recon_vec = needles_vec.permute(0, 2, 1).reshape(vec.shape[0], -1)
874
+
875
+ needles_epsilon = torch.matmul(self.V_small, needles_epsilon.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1,self.channels)
876
+ recon_epsilon = needles_epsilon.permute(0, 2, 1).reshape(vec.shape[0], -1)
877
+
878
+ return recon_vec + recon_epsilon
879
+
880
+ #Walsh-Hadamard Compressive Sensing
881
+ class WalshAadamardCS(A_functions):
882
+ def fwht(self, vec): #the Fast Walsh-Hadamard Transform is its own inverse
883
+ a = vec.reshape(vec.shape[0], self.channels, self.img_dim**2)
884
+ h = 1
885
+ while h < self.img_dim**2:
886
+ a = a.reshape(vec.shape[0], self.channels, -1, h * 2)
887
+ b = a.clone()
888
+ a[:, :, :, :h] = b[:, :, :, :h] + b[:, :, :, h:2*h]
889
+ a[:, :, :, h:2*h] = b[:, :, :, :h] - b[:, :, :, h:2*h]
890
+ h *= 2
891
+ a = a.reshape(vec.shape[0], self.channels, self.img_dim**2) / self.img_dim
892
+ return a
893
+
894
+ def __init__(self, channels, img_dim, ratio, perm, device):
895
+ self.channels = channels
896
+ self.img_dim = img_dim
897
+ self.ratio = ratio
898
+ self.perm = perm
899
+ self._singulars = torch.ones(channels * img_dim**2 // ratio, device=device)
900
+
901
+ def V(self, vec):
902
+ temp = torch.zeros(vec.shape[0], self.channels, self.img_dim**2, device=vec.device)
903
+ temp[:, :, self.perm] = vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1)
904
+ return self.fwht(temp).reshape(vec.shape[0], -1)
905
+
906
+ def Vt(self, vec):
907
+ return self.fwht(vec.clone())[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1)
908
+
909
+ def U(self, vec):
910
+ return vec.clone().reshape(vec.shape[0], -1)
911
+
912
+ def Ut(self, vec):
913
+ return vec.clone().reshape(vec.shape[0], -1)
914
+
915
+ def singulars(self):
916
+ return self._singulars
917
+
918
+ def add_zeros(self, vec):
919
+ out = torch.zeros(vec.shape[0], self.channels * self.img_dim**2, device=vec.device)
920
+ out[:, :self.channels * self.img_dim**2 // self.ratio] = vec.clone().reshape(vec.shape[0], -1)
921
+ return out
922
+
923
+ def Lambda(self, vec, a, sigma_y, sigma_t, eta):
924
+ temp_vec = self.fwht(vec.clone())[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1)
925
+
926
+ singulars = self._singulars
927
+ lambda_t = torch.ones(self.channels * self.img_dim ** 2, device=vec.device)
928
+ temp = torch.zeros(self.channels * self.img_dim ** 2, device=vec.device)
929
+ temp[:singulars.size(0)] = singulars
930
+ singulars = temp
931
+ inverse_singulars = 1. / singulars
932
+ inverse_singulars[singulars == 0] = 0.
933
+
934
+ if a != 0 and sigma_y != 0:
935
+ change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
936
+ lambda_t = lambda_t * (-change_index + 1.0) + change_index * (
937
+ singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y)
938
+
939
+ lambda_t = lambda_t.reshape(1, -1)
940
+ temp_vec = temp_vec * lambda_t
941
+
942
+ temp_out = torch.zeros(vec.shape[0], self.channels, self.img_dim ** 2, device=vec.device)
943
+ temp_out[:, :, self.perm] = temp_vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1)
944
+ return self.fwht(temp_out).reshape(vec.shape[0], -1)
945
+
946
+ def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
947
+ temp_vec = vec.clone().reshape(
948
+ vec.shape[0], self.channels, self.img_dim ** 2)[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1)
949
+ temp_eps = epsilon.clone().reshape(
950
+ vec.shape[0], self.channels, self.img_dim ** 2)[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1)
951
+
952
+ d1_t = torch.ones(self.channels * self.img_dim ** 2, device=vec.device) * sigma_t * eta
953
+ d2_t = torch.ones(self.channels * self.img_dim ** 2, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5
954
+
955
+ singulars = self._singulars
956
+ temp = torch.zeros(self.channels * self.img_dim ** 2, device=vec.device)
957
+ temp[:singulars.size(0)] = singulars
958
+ singulars = temp
959
+ inverse_singulars = 1. / singulars
960
+ inverse_singulars[singulars == 0] = 0.
961
+
962
+ if a != 0 and sigma_y != 0:
963
+ change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
964
+ d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
965
+ d2_t = d2_t * (-change_index + 1.0)
966
+
967
+ change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0
968
+ d1_t = d1_t * (-change_index + 1.0) + torch.sqrt(
969
+ change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2))
970
+ d2_t = d2_t * (-change_index + 1.0)
971
+
972
+ change_index = (singulars == 0) * 1.0
973
+ d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
974
+ d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5
975
+
976
+ d1_t = d1_t.reshape(1, -1)
977
+ d2_t = d2_t.reshape(1, -1)
978
+
979
+ temp_vec = temp_vec * d1_t
980
+ temp_eps = temp_eps * d2_t
981
+
982
+ temp_out_vec = torch.zeros(vec.shape[0], self.channels, self.img_dim ** 2, device=vec.device)
983
+ temp_out_vec[:, :, self.perm] = temp_vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1)
984
+ temp_out_vec = self.fwht(temp_out_vec).reshape(vec.shape[0], -1)
985
+
986
+ temp_out_eps = torch.zeros(vec.shape[0], self.channels, self.img_dim ** 2, device=vec.device)
987
+ temp_out_eps[:, :, self.perm] = temp_eps.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1)
988
+ temp_out_eps = self.fwht(temp_out_eps).reshape(vec.shape[0], -1)
989
+
990
+ return temp_out_vec + temp_out_eps
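As a quick numeric sanity check of the class above (not part of the diff): because fwht divides by img_dim, i.e. by the square root of the transform size img_dim**2, applying it twice recovers the input. Note that fwht writes into its argument in place, so the sketch below passes clones. The concrete sizes and the random permutation are illustrative only.

import torch

img_dim, channels, ratio = 4, 1, 2
perm = torch.randperm(img_dim ** 2)                      # placeholder permutation
A = WalshAadamardCS(channels, img_dim, ratio, perm, device="cpu")

x = torch.randn(1, channels * img_dim ** 2)
y = A.fwht(x.clone())                                    # forward transform
x_rec = A.fwht(y.clone())                                # transform again -> original signal
print(torch.allclose(x_rec.reshape(1, -1), x, atol=1e-5))  # True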
991
+
992
+ #Convolution-based super-resolution
993
+ class SRConv(A_functions):
994
+ def mat_by_img(self, M, v, dim):
995
+ return torch.matmul(M, v.reshape(v.shape[0] * self.channels, dim,
996
+ dim)).reshape(v.shape[0], self.channels, M.shape[0], dim)
997
+
998
+ def img_by_mat(self, v, M, dim):
999
+ return torch.matmul(v.reshape(v.shape[0] * self.channels, dim,
1000
+ dim), M).reshape(v.shape[0], self.channels, dim, M.shape[1])
1001
+
1002
+ def __init__(self, kernel, channels, img_dim, device, stride = 1):
1003
+ self.img_dim = img_dim
1004
+ self.channels = channels
1005
+ self.ratio = stride
1006
+ small_dim = img_dim // stride
1007
+ self.small_dim = small_dim
1008
+ #build 1D conv matrix
1009
+ A_small = torch.zeros(small_dim, img_dim, device=device)
1010
+ for i in range(stride//2, img_dim + stride//2, stride):
1011
+ for j in range(i - kernel.shape[0]//2, i + kernel.shape[0]//2):
1012
+ j_effective = j
1013
+ #reflective padding
1014
+ if j_effective < 0: j_effective = -j_effective-1
1015
+ if j_effective >= img_dim: j_effective = (img_dim - 1) - (j_effective - img_dim)
1016
+ #matrix building
1017
+ A_small[i // stride, j_effective] += kernel[j - i + kernel.shape[0]//2]
1018
+ #get the svd of the 1D conv
1019
+ self.U_small, self.singulars_small, self.V_small = torch.svd(A_small, some=False)
1020
+ ZERO = 3e-2
1021
+ self.singulars_small[self.singulars_small < ZERO] = 0
1022
+ #calculate the singular values of the big matrix
1023
+ self._singulars = torch.matmul(self.singulars_small.reshape(small_dim, 1), self.singulars_small.reshape(1, small_dim)).reshape(small_dim**2)
1024
+ #permutation for matching the singular values. See P_1 in Appendix D.5.
1025
+ self._perm = torch.Tensor([self.img_dim * i + j for i in range(self.small_dim) for j in range(self.small_dim)] + \
1026
+ [self.img_dim * i + j for i in range(self.small_dim) for j in range(self.small_dim, self.img_dim)]).to(device).long()
1027
+
1028
+ def V(self, vec):
1029
+ #invert the permutation
1030
+ temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
1031
+ temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)[:, :self._perm.shape[0], :]
1032
+ temp[:, self._perm.shape[0]:, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)[:, self._perm.shape[0]:, :]
1033
+ temp = temp.permute(0, 2, 1)
1034
+ #multiply the image by V from the left and by V^T from the right
1035
+ out = self.mat_by_img(self.V_small, temp, self.img_dim)
1036
+ out = self.img_by_mat(out, self.V_small.transpose(0, 1), self.img_dim).reshape(vec.shape[0], -1)
1037
+ return out
1038
+
1039
+ def Vt(self, vec):
1040
+ #multiply the image by V^T from the left and by V from the right
1041
+ temp = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone(), self.img_dim)
1042
+ temp = self.img_by_mat(temp, self.V_small, self.img_dim).reshape(vec.shape[0], self.channels, -1)
1043
+ #permute the entries
1044
+ temp[:, :, :self._perm.shape[0]] = temp[:, :, self._perm]
1045
+ temp = temp.permute(0, 2, 1)
1046
+ return temp.reshape(vec.shape[0], -1)
1047
+
1048
+ def U(self, vec):
1049
+ #invert the permutation
1050
+ temp = torch.zeros(vec.shape[0], self.small_dim**2, self.channels, device=vec.device)
1051
+ temp[:, :self.small_dim**2, :] = vec.clone().reshape(vec.shape[0], self.small_dim**2, self.channels)
1052
+ temp = temp.permute(0, 2, 1)
1053
+ #multiply the image by U from the left and by U^T from the right
1054
+ out = self.mat_by_img(self.U_small, temp, self.small_dim)
1055
+ out = self.img_by_mat(out, self.U_small.transpose(0, 1), self.small_dim).reshape(vec.shape[0], -1)
1056
+ return out
1057
+
1058
+ def Ut(self, vec):
1059
+ #multiply the image by U^T from the left and by U from the right
1060
+ temp = self.mat_by_img(self.U_small.transpose(0, 1), vec.clone(), self.small_dim)
1061
+ temp = self.img_by_mat(temp, self.U_small, self.small_dim).reshape(vec.shape[0], self.channels, -1)
1062
+ #permute the entries
1063
+ temp = temp.permute(0, 2, 1)
1064
+ return temp.reshape(vec.shape[0], -1)
1065
+
1066
+ def singulars(self):
1067
+ return self._singulars.repeat_interleave(3).reshape(-1)
1068
+
1069
+ def add_zeros(self, vec):
1070
+ reshaped = vec.clone().reshape(vec.shape[0], -1)
1071
+ temp = torch.zeros((vec.shape[0], reshaped.shape[1] * self.ratio**2), device=vec.device)
1072
+ temp[:, :reshaped.shape[1]] = reshaped
1073
+ return temp
1074
+
1075
+ #Deblurring
1076
+ class Deblurring(A_functions):
1077
+ def mat_by_img(self, M, v):
1078
+ return torch.matmul(M, v.reshape(v.shape[0] * self.channels, self.img_dim,
1079
+ self.img_dim)).reshape(v.shape[0], self.channels, M.shape[0], self.img_dim)
1080
+
1081
+ def img_by_mat(self, v, M):
1082
+ return torch.matmul(v.reshape(v.shape[0] * self.channels, self.img_dim,
1083
+ self.img_dim), M).reshape(v.shape[0], self.channels, self.img_dim, M.shape[1])
1084
+
1085
+ def __init__(self, kernel, channels, img_dim, device, ZERO = 3e-2):
1086
+ self.img_dim = img_dim
1087
+ self.channels = channels
1088
+ #build 1D conv matrix
1089
+ A_small = torch.zeros(img_dim, img_dim, device=device)
1090
+ for i in range(img_dim):
1091
+ for j in range(i - kernel.shape[0]//2, i + kernel.shape[0]//2):
1092
+ if j < 0 or j >= img_dim: continue
1093
+ A_small[i, j] = kernel[j - i + kernel.shape[0]//2]
1094
+ #get the svd of the 1D conv
1095
+ self.U_small, self.singulars_small, self.V_small = torch.svd(A_small, some=False)
1096
+ #ZERO = 3e-2
1097
+ self.singulars_small_orig = self.singulars_small.clone()
1098
+ self.singulars_small[self.singulars_small < ZERO] = 0
1099
+ #calculate the singular values of the big matrix
1100
+ self._singulars_orig = torch.matmul(self.singulars_small_orig.reshape(img_dim, 1), self.singulars_small_orig.reshape(1, img_dim)).reshape(img_dim**2)
1101
+ self._singulars = torch.matmul(self.singulars_small.reshape(img_dim, 1), self.singulars_small.reshape(1, img_dim)).reshape(img_dim**2)
1102
+ #sort the big matrix singulars and save the permutation
1103
+ self._singulars, self._perm = self._singulars.sort(descending=True) #, stable=True)
1104
+ self._singulars_orig = self._singulars_orig[self._perm]
1105
+
1106
+ def V(self, vec):
1107
+ #invert the permutation
1108
+ temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
1109
+ temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)
1110
+ temp = temp.permute(0, 2, 1)
1111
+ #multiply the image by V from the left and by V^T from the right
1112
+ out = self.mat_by_img(self.V_small, temp)
1113
+ out = self.img_by_mat(out, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1)
1114
+ return out
1115
+
1116
+ def Vt(self, vec):
1117
+ #multiply the image by V^T from the left and by V from the right
1118
+ temp = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone())
1119
+ temp = self.img_by_mat(temp, self.V_small).reshape(vec.shape[0], self.channels, -1)
1120
+ #permute the entries according to the singular values
1121
+ temp = temp[:, :, self._perm].permute(0, 2, 1)
1122
+ return temp.reshape(vec.shape[0], -1)
1123
+
1124
+ def U(self, vec):
1125
+ #invert the permutation
1126
+ temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
1127
+ temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)
1128
+ temp = temp.permute(0, 2, 1)
1129
+ #multiply the image by U from the left and by U^T from the right
1130
+ out = self.mat_by_img(self.U_small, temp)
1131
+ out = self.img_by_mat(out, self.U_small.transpose(0, 1)).reshape(vec.shape[0], -1)
1132
+ return out
1133
+
1134
+ def Ut(self, vec):
1135
+ #multiply the image by U^T from the left and by U from the right
1136
+ temp = self.mat_by_img(self.U_small.transpose(0, 1), vec.clone())
1137
+ temp = self.img_by_mat(temp, self.U_small).reshape(vec.shape[0], self.channels, -1)
1138
+ #permute the entries according to the singular values
1139
+ temp = temp[:, :, self._perm].permute(0, 2, 1)
1140
+ return temp.reshape(vec.shape[0], -1)
1141
+
1142
+ def singulars(self):
1143
+ return self._singulars.repeat(1, 3).reshape(-1)
1144
+
1145
+ def add_zeros(self, vec):
1146
+ return vec.clone().reshape(vec.shape[0], -1)
1147
+
1148
+ def A_pinv(self, vec):
1149
+ temp = self.Ut(vec)
1150
+ singulars = self._singulars.repeat(1, 3).reshape(-1)
1151
+
1152
+ factors = 1. / singulars
1153
+ factors[singulars == 0] = 0.
1154
+
1155
+ temp[:, :singulars.shape[0]] = temp[:, :singulars.shape[0]] * factors
1156
+ return self.V(self.add_zeros(temp))
1157
+
1158
+ def Lambda(self, vec, a, sigma_y, sigma_t, eta):
1159
+ temp_vec = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone())
1160
+ temp_vec = self.img_by_mat(temp_vec, self.V_small).reshape(vec.shape[0], self.channels, -1)
1161
+ temp_vec = temp_vec[:, :, self._perm].permute(0, 2, 1)
1162
+
1163
+ singulars = self._singulars_orig
1164
+ lambda_t = torch.ones(self.img_dim ** 2, device=vec.device)
1165
+ temp_singulars = torch.zeros(self.img_dim ** 2, device=vec.device)
1166
+ temp_singulars[:singulars.size(0)] = singulars
1167
+ singulars = temp_singulars
1168
+ inverse_singulars = 1. / singulars
1169
+ inverse_singulars[singulars == 0] = 0.
1170
+
1171
+ if a != 0 and sigma_y != 0:
1172
+ change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
1173
+ lambda_t = lambda_t * (-change_index + 1.0) + change_index * (
1174
+ singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y)
1175
+
1176
+ lambda_t = lambda_t.reshape(1, -1, 1)
1177
+ temp_vec = temp_vec * lambda_t
1178
+
1179
+ temp = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device)
1180
+ temp[:, self._perm, :] = temp_vec.clone().reshape(vec.shape[0], self.img_dim ** 2, self.channels)
1181
+ temp = temp.permute(0, 2, 1)
1182
+ out = self.mat_by_img(self.V_small, temp)
1183
+ out = self.img_by_mat(out, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1)
1184
+ return out
1185
+
1186
+ def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon):
1187
+ temp_vec = vec.clone().reshape(vec.shape[0], self.channels, -1)
1188
+ temp_vec = temp_vec[:, :, self._perm].permute(0, 2, 1)
1189
+
1190
+ temp_eps = epsilon.clone().reshape(vec.shape[0], self.channels, -1)
1191
+ temp_eps = temp_eps[:, :, self._perm].permute(0, 2, 1)
1192
+
1193
+ singulars = self._singulars_orig
1194
+ d1_t = torch.ones(self.img_dim ** 2, device=vec.device) * sigma_t * eta
1195
+ d2_t = torch.ones(self.img_dim ** 2, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5
1196
+
1197
+ temp_singulars = torch.zeros(self.img_dim ** 2, device=vec.device)
1198
+ temp_singulars[:singulars.size(0)] = singulars
1199
+ singulars = temp_singulars
1200
+ inverse_singulars = 1. / singulars
1201
+ inverse_singulars[singulars == 0] = 0.
1202
+
1203
+ if a != 0 and sigma_y != 0:
1204
+ change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0
1205
+ d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
1206
+ d2_t = d2_t * (-change_index + 1.0)
1207
+
1208
+ change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0
1209
+ d1_t = d1_t * (-change_index + 1.0) + torch.sqrt(
1210
+ change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2))
1211
+ d2_t = d2_t * (-change_index + 1.0)
1212
+
1213
+ change_index = (singulars == 0) * 1.0
1214
+ d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta
1215
+ d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5
1216
+
1217
+ d1_t = d1_t.reshape(1, -1, 1)
1218
+ d2_t = d2_t.reshape(1, -1, 1)
1219
+
1220
+ temp_vec = temp_vec * d1_t
1221
+ temp_eps = temp_eps * d2_t
1222
+
1223
+ temp_vec_new = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device)
1224
+ temp_vec_new[:, self._perm, :] = temp_vec
1225
+ out_vec = self.mat_by_img(self.V_small, temp_vec_new.permute(0, 2, 1))
1226
+ out_vec = self.img_by_mat(out_vec, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1)
1227
+
1228
+ temp_eps_new = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device)
1229
+ temp_eps_new[:, self._perm, :] = temp_eps
1230
+ out_eps = self.mat_by_img(self.V_small, temp_eps_new.permute(0, 2, 1))
1231
+ out_eps = self.img_by_mat(out_eps, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1)
1232
+
1233
+ return out_vec + out_eps
1234
+
1235
+ #Anisotropic Deblurring
1236
+ class Deblurring2D(A_functions):
1237
+ def mat_by_img(self, M, v):
1238
+ return torch.matmul(M, v.reshape(v.shape[0] * self.channels, self.img_dim,
1239
+ self.img_dim)).reshape(v.shape[0], self.channels, M.shape[0], self.img_dim)
1240
+
1241
+ def img_by_mat(self, v, M):
1242
+ return torch.matmul(v.reshape(v.shape[0] * self.channels, self.img_dim,
1243
+ self.img_dim), M).reshape(v.shape[0], self.channels, self.img_dim, M.shape[1])
1244
+
1245
+ def __init__(self, kernel1, kernel2, channels, img_dim, device):
1246
+ self.img_dim = img_dim
1247
+ self.channels = channels
1248
+ A_small1 = torch.zeros(img_dim, img_dim, device=device)
1249
+ for i in range(img_dim):
1250
+ for j in range(i - kernel1.shape[0]//2, i + kernel1.shape[0]//2):
1251
+ if j < 0 or j >= img_dim: continue
1252
+ A_small1[i, j] = kernel1[j - i + kernel1.shape[0]//2]
1253
+ A_small2 = torch.zeros(img_dim, img_dim, device=device)
1254
+ for i in range(img_dim):
1255
+ for j in range(i - kernel2.shape[0]//2, i + kernel2.shape[0]//2):
1256
+ if j < 0 or j >= img_dim: continue
1257
+ A_small2[i, j] = kernel2[j - i + kernel2.shape[0]//2]
1258
+ self.U_small1, self.singulars_small1, self.V_small1 = torch.svd(A_small1, some=False)
1259
+ self.U_small2, self.singulars_small2, self.V_small2 = torch.svd(A_small2, some=False)
1260
+ ZERO = 3e-2
1261
+ self.singulars_small1[self.singulars_small1 < ZERO] = 0
1262
+ self.singulars_small2[self.singulars_small2 < ZERO] = 0
1263
+ #calculate the singular values of the big matrix
1264
+ self._singulars = torch.matmul(self.singulars_small1.reshape(img_dim, 1), self.singulars_small2.reshape(1, img_dim)).reshape(img_dim**2)
1265
+ #sort the big matrix singulars and save the permutation
1266
+ self._singulars, self._perm = self._singulars.sort(descending=True) #, stable=True)
1267
+
1268
+ def V(self, vec):
1269
+ #invert the permutation
1270
+ temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
1271
+ temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)
1272
+ temp = temp.permute(0, 2, 1)
1273
+ #multiply the image by V from the left and by V^T from the right
1274
+ out = self.mat_by_img(self.V_small1, temp)
1275
+ out = self.img_by_mat(out, self.V_small2.transpose(0, 1)).reshape(vec.shape[0], -1)
1276
+ return out
1277
+
1278
+ def Vt(self, vec):
1279
+ #multiply the image by V^T from the left and by V from the right
1280
+ temp = self.mat_by_img(self.V_small1.transpose(0, 1), vec.clone())
1281
+ temp = self.img_by_mat(temp, self.V_small2).reshape(vec.shape[0], self.channels, -1)
1282
+ #permute the entries according to the singular values
1283
+ temp = temp[:, :, self._perm].permute(0, 2, 1)
1284
+ return temp.reshape(vec.shape[0], -1)
1285
+
1286
+ def U(self, vec):
1287
+ #invert the permutation
1288
+ temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
1289
+ temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)
1290
+ temp = temp.permute(0, 2, 1)
1291
+ #multiply the image by U from the left and by U^T from the right
1292
+ out = self.mat_by_img(self.U_small1, temp)
1293
+ out = self.img_by_mat(out, self.U_small2.transpose(0, 1)).reshape(vec.shape[0], -1)
1294
+ return out
1295
+
1296
+ def Ut(self, vec):
1297
+ #multiply the image by U^T from the left and by U from the right
1298
+ temp = self.mat_by_img(self.U_small1.transpose(0, 1), vec.clone())
1299
+ temp = self.img_by_mat(temp, self.U_small2).reshape(vec.shape[0], self.channels, -1)
1300
+ #permute the entries according to the singular values
1301
+ temp = temp[:, :, self._perm].permute(0, 2, 1)
1302
+ return temp.reshape(vec.shape[0], -1)
1303
+
1304
+ def singulars(self):
1305
+ return self._singulars.repeat(1, 3).reshape(-1)
1306
+
1307
+ def add_zeros(self, vec):
1308
+ return vec.clone().reshape(vec.shape[0], -1)
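A hedged usage sketch for the SVD-based operators above (not part of the diff): the Colorization operator's Vt, singulars and U methods can be composed by hand to apply the forward operator, assuming the A_functions base class chains them as U(Sigma * Vt(x)). The batch size and image size below are illustrative.

import torch

device = "cpu"
img_dim = 64
A = Colorization(img_dim, device)

x = torch.randn(2, 3 * img_dim * img_dim, device=device)   # flattened RGB batch
coeffs = A.Vt(x)                                            # rotate into the SVD basis
sv = A.singulars()                                          # img_dim**2 singular values
# the first img_dim**2 coefficients correspond to the single non-zero singular value
y = A.U(sv * coeffs[:, : sv.shape[0]])                      # roughly the per-pixel channel average
print(y.shape)                                              # torch.Size([2, 4096])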
src/flair/helper_functions.py ADDED
@@ -0,0 +1,31 @@
1
+ import os
2
+ import yaml
3
+ import click
4
+ import subprocess
5
+
6
+ def parse_click_context(ctx):
7
+ """Parse additional arguments passed via Click context."""
8
+ extra_args = {}
9
+ for arg in ctx.args:
10
+ if "=" in arg:
11
+ key, value = arg.split("=", 1)
12
+ if key.startswith("--"):
13
+ key = key[2:] # Remove leading "--"
14
+ extra_args[key] = yaml.safe_load(value)
15
+ return extra_args
16
+
17
+ def generate_captions_with_seesr(pseudo_inv_dir, output_caption_file):
18
+ """Generate captions using the SEESR model."""
19
+ command = [
20
+ "conda",
21
+ "run",
22
+ "-n",
23
+ "seesr",
24
+ "python",
25
+ "/home/erbachj/scratch2/projects/var_post_samp/scripts/generate_caption.py",
26
+ "--input_dir",
27
+ pseudo_inv_dir,
28
+ "--output_file",
29
+ output_caption_file, # Corrected argument name
30
+ ]
31
+ subprocess.run(command, check=True)
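For reference, a minimal sketch (not part of the diff) of how parse_click_context interprets extra --key=value arguments: yaml.safe_load turns numeric strings into numbers and "true"/"false" into booleans. The SimpleNamespace stands in for a click.Context and the argument names are illustrative.

from types import SimpleNamespace

ctx = SimpleNamespace(args=["--sigma_y=0.05", "--use_tiny_ae=true", "--prompt=a photo"])
extra = parse_click_context(ctx)
print(extra)  # {'sigma_y': 0.05, 'use_tiny_ae': True, 'prompt': 'a photo'}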
src/flair/pipelines/__init__.py ADDED
File without changes
src/flair/pipelines/model_loader.py ADDED
@@ -0,0 +1,97 @@
1
+ # functions to load models from config
2
+ import numpy as np
3
+ import torch
4
+ import re
5
+ import os
6
+ import math
7
+ from diffusers import (
8
+ BitsAndBytesConfig
9
+ )
10
+ from diffusers import AutoencoderTiny
11
+
12
+
13
+ from flair.pipelines import sd3
14
+
15
+
16
+
17
+ @torch.no_grad()
18
+ def load_sd3(config, device):
19
+ if isinstance(device, list):
20
+ device = device[0]
21
+ if config["quantize"]:
22
+ nf4_config = BitsAndBytesConfig(
23
+ load_in_4bit=True,
24
+ bnb_4bit_quant_type="nf4",
25
+ bnb_4bit_compute_dtype=torch.bfloat16
26
+ )
27
+ else:
28
+ nf4_config = None
29
+ if config["model"] == "SD3.5-large":
30
+ pipe = sd3.SD3Wrapper.from_pretrained(
31
+ "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16, quantization_config=nf4_config
32
+ )
33
+ elif config["model"] == "SD3.5-large-turbo":
34
+ pipe = sd3.SD3Wrapper.from_pretrained(
35
+ "stabilityai/stable-diffusion-3.5-large-turbo", torch_dtype=torch.bfloat16, quantization_config=nf4_config,
36
+ )
37
+ else:
38
+ pipe = sd3.SD3Wrapper.from_pretrained("stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.float16, quantization_config=nf4_config)
39
+ # maybe use tiny autoencoder
40
+ if config["use_tiny_ae"]:
41
+ pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd3", torch_dtype=torch.float16)
42
+
43
+ # encode prompts
44
+ inp_kwargs_list = []
45
+ prompts = config["prompt"]
46
+ pipe._guidance_scale = config["guidance"]
47
+ pipe._joint_attention_kwargs = {"ip_adapter_image_embeds": None}
48
+ for prompt in prompts:
49
+ print(f"Generating prompt embeddings for: {prompt}")
50
+ pipe.text_encoder.to(device).to(torch.bfloat16)
51
+ pipe.text_encoder_2.to(device).to(torch.bfloat16)
52
+ pipe.text_encoder_3.to(device).to(torch.bfloat16)
53
+ # encode
54
+ (
55
+ prompt_embeds,
56
+ negative_prompt_embeds,
57
+ pooled_prompt_embeds,
58
+ negative_pooled_prompt_embeds,
59
+ ) = pipe.encode_prompt(
60
+ prompt=prompt,
61
+ prompt_2=prompt,
62
+ prompt_3=prompt,
63
+ negative_prompt=config["negative_prompt"],
64
+ negative_prompt_2=config["negative_prompt"],
65
+ negative_prompt_3=config["negative_prompt"],
66
+ do_classifier_free_guidance=pipe.do_classifier_free_guidance,
67
+ prompt_embeds=None,
68
+ negative_prompt_embeds=None,
69
+ pooled_prompt_embeds=None,
70
+ negative_pooled_prompt_embeds=None,
71
+ device=device,
72
+ clip_skip=None,
73
+ num_images_per_prompt=1,
74
+ max_sequence_length=256,
75
+ lora_scale=None,
76
+ )
77
+
78
+ if pipe.do_classifier_free_guidance:
79
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
80
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
81
+ inp_kwargs = {
82
+ "prompt_embeds": prompt_embeds,
83
+ "pooled_prompt_embeds": pooled_prompt_embeds,
84
+ "guidance": config["guidance"],
85
+ }
86
+ inp_kwargs_list.append(inp_kwargs)
87
+ pipe.vae.to(device).to(torch.bfloat16)
88
+ pipe.transformer.to(device).to(torch.bfloat16)
89
+
90
+
91
+ return pipe, inp_kwargs_list
92
+
93
+ def load_model(config, device=["cuda"]):
94
+ if re.match(r"SD3*", config["model"]):
95
+ return load_sd3(config, device)
96
+ else:
97
+ raise ValueError(f"Unknown model type {config['model']}")
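The loader above only reads a handful of config keys. A hedged example of a minimal config dict is sketched below; the keys mirror the accesses in load_sd3, but the concrete values are illustrative, not repository defaults.

# Illustrative config for load_model / load_sd3 (values are assumptions).
config = {
    "model": "SD3.5-large",          # or "SD3.5-large-turbo"; anything else falls back to medium
    "quantize": True,                # builds a 4-bit NF4 BitsAndBytesConfig for from_pretrained
    "use_tiny_ae": False,            # optionally swap the VAE for madebyollin/taesd3
    "prompt": ["a photo of a sunflower field"],
    "negative_prompt": "",
    "guidance": 7.0,
}

pipe, inp_kwargs_list = load_model(config, device=["cuda"])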
src/flair/pipelines/sd3.py ADDED
@@ -0,0 +1,111 @@
1
+ import torch
2
+ from typing import Dict, Any
3
+ from diffusers.pipelines.stable_diffusion_3 import pipeline_stable_diffusion_3
4
+ from flair.pipelines import utils
5
+ import tqdm
6
+
7
+ class SD3Wrapper(pipeline_stable_diffusion_3.StableDiffusion3Pipeline):
8
+ def to(self, device, kwargs):
9
+ self.transformer.to(device)
10
+ self.vae.to(device)
11
+ return self
12
+
13
+ def get_timesteps(self, n_steps, device, ts_min=0):
14
+ # Create a linear schedule for timesteps
15
+ timesteps = torch.linspace(1, ts_min, n_steps+2, device=device, dtype=torch.float32)
16
+ return timesteps[1:-1] # Exclude the first and last timesteps
17
+
18
+ def single_step(
19
+ self,
20
+ img_latent: torch.Tensor,
21
+ t: torch.Tensor,
22
+ kwargs: Dict[str, Any],
23
+ is_noised_latent = False,
24
+ ):
25
+ if "noise" in kwargs:
26
+ noise = kwargs["noise"].detach()
27
+ alpha = kwargs["inv_alpha"]
28
+ if alpha == "tsqrt":
29
+ alpha = t**0.5 # * 0.75
30
+ elif alpha == "t":
31
+ alpha = t
32
+ elif alpha == "sine":
33
+ alpha = torch.sin(t * 3.141592653589793/2)
34
+ elif alpha == "1-t":
35
+ alpha = 1 - t
36
+ elif alpha == "1-t*0.5":
37
+ alpha = (1 - t)*0.5
38
+ elif alpha == "1-t*0.9":
39
+ alpha = (1 - t)*0.9
40
+ elif alpha == "t**1/3":
41
+ alpha = t**(1/3)
42
+ elif alpha == "(1-t)**0.5":
43
+ alpha = (1-t)**0.5
44
+ elif alpha == "((1-t)*0.8)**0.5":
45
+ alpha = (1-t*0.8)**0.5
46
+ elif alpha == "(1-t)**2":
47
+ alpha = (1-t)**2
48
+ # alpha = t * kwargs["inv_alpha"]
49
+ noise = (alpha) * noise + (1-alpha**2)**0.5 * torch.randn_like(img_latent)
50
+ # noise = noise / noise.std()
51
+ # noise = noise / (1- 2*alpha*(1-alpha))**0.5
52
+ # noise = noise + alpha * torch.randn_like(img_latent)
53
+ else:
54
+ noise = torch.randn_like(img_latent)
55
+ if is_noised_latent:
56
+ noised_latent = img_latent
57
+ else:
58
+ noised_latent = t * noise + (1 - t) * img_latent
59
+ latent_model_input = torch.cat([noised_latent] * 2) if self.do_classifier_free_guidance else noised_latent
60
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
61
+ timestep = t.expand(latent_model_input.shape[0])
62
+ noise_pred = self.transformer(
63
+ hidden_states=latent_model_input.to(img_latent.dtype),
64
+ timestep=(timestep*1000).to(img_latent.dtype),
65
+ encoder_hidden_states=kwargs["prompt_embeds"].repeat(img_latent.shape[0], 1, 1),
66
+ pooled_projections=kwargs["pooled_prompt_embeds"].repeat(img_latent.shape[0], 1),
67
+ joint_attention_kwargs=None,
68
+ return_dict=False,
69
+ )[0]
70
+ if self.do_classifier_free_guidance:
71
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
72
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
73
+
74
+ eps = utils.v_to_eps(noise_pred, t, noised_latent)
75
+ return eps, noise, (1 - t), t, noise_pred
76
+
77
+ def encode(self, img):
78
+ # Encode the image into latent space
79
+
80
+ img_latent = self.vae.encode(img, return_dict=False)[0]
81
+ if hasattr(img_latent, "sample"):
82
+ img_latent = img_latent.sample()
83
+ img_latent = (img_latent - self.vae.config.shift_factor) * self.vae.config.scaling_factor
84
+ return img_latent
85
+
86
+ def decode(self, img_latent):
87
+ # Decode the latent representation back to image space
88
+ img = self.vae.decode(img_latent / self.vae.config.scaling_factor + self.vae.config.shift_factor, return_dict=False)[0]
89
+ return img
90
+
91
+ def denoise(self, pseudo_inv, kwargs, inverse=False):
92
+ # get timesteps
93
+ timesteps = torch.linspace(1, 0, kwargs["n_steps"], device=pseudo_inv.device, dtype=pseudo_inv.dtype)
94
+ sigmas = timesteps
95
+ if inverse:
96
+ timesteps = timesteps.flip(0)
97
+ sigmas = sigmas.flip(0)
98
+
99
+ # make a single step
100
+ for i, t in tqdm.tqdm(enumerate(timesteps[:-1]), desc="Denoising", total=len(timesteps)-1):
101
+ eps, noise, _, t, v = self.single_step(
102
+ pseudo_inv,
103
+ t.to("cuda")*1000,
104
+ kwargs,
105
+ is_noised_latent=True,
106
+ )
107
+ # step
108
+ sigma_next = sigmas[i+1]
109
+ sigma_t = sigmas[i]
110
+ pseudo_inv = pseudo_inv + v * (sigma_next - sigma_t)
111
+ return pseudo_inv
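denoise above integrates the predicted velocity with plain Euler steps, z <- z + v * (sigma_next - sigma_t), over a linear sigma grid. A toy, self-contained sketch of that update rule (with a made-up velocity function standing in for the transformer call) is:

import torch

def toy_velocity(z, t):
    # stand-in for the transformer's velocity prediction; purely illustrative
    return -z

z = torch.randn(4)
sigmas = torch.linspace(1.0, 0.0, 28)          # linear sigma grid, as in denoise()
for i, t in enumerate(sigmas[:-1]):
    v = toy_velocity(z, t)
    z = z + v * (sigmas[i + 1] - sigmas[i])    # Euler step, sigma decreases toward 0
print(z)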
src/flair/pipelines/utils.py ADDED
@@ -0,0 +1,114 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ def eps_from_v(z_0, z_t, sigma_t):
6
+ return (z_t - z_0) / sigma_t
7
+
8
+ def v_to_eps(v, t, x_t):
9
+ """
10
+ function to compute the epsilon parametrization from the velocity field
11
+ with x_t = t * x_0 + (1 - t) * x_1 with x_0 ~ N(0,I)
12
+ """
13
+ eps_t = (1-t)*v + x_t
14
+ return eps_t
15
+
16
+
17
+ def clip_gradients(gradients, clip_value):
18
+ grad_norm = gradients.norm(dim=2)
19
+ mask = grad_norm > clip_value
20
+ mask_exp = mask[:, :, None].expand_as(gradients)
21
+ gradients[mask_exp] = (
22
+ gradients[mask_exp]
23
+ / grad_norm[:, :, None].expand_as(gradients)[mask_exp]
24
+ * clip_value
25
+ )
26
+ return gradients
27
+
28
+
29
+ class Adam:
30
+
31
+ def __init__(self, parameters, lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
32
+ self.lr = lr
33
+ self.beta1 = beta1
34
+ self.beta2 = beta2
35
+ self.epsilon = epsilon
36
+ self.t = 0
37
+ self.m = torch.zeros_like(parameters)
38
+ self.v = torch.zeros_like(parameters)
39
+
40
+ def step(self, params, grad) -> torch.Tensor:
41
+ self.t += 1
42
+
43
+ self.m = self.beta1 * self.m + (1 - self.beta1) * grad
44
+ self.v = self.beta2 * self.v + (1 - self.beta2) * grad**2
45
+
46
+ m_hat = self.m / (1 - self.beta1**self.t)
47
+ v_hat = self.v / (1 - self.beta2**self.t)
48
+
49
+ # check if self.lr is callable
50
+ if callable(self.lr):
51
+ lr = self.lr(self.t - 1)
52
+ else:
53
+ lr = self.lr
54
+ update = lr * m_hat / (torch.sqrt(v_hat) + self.epsilon)
55
+
56
+ return params - update
57
+
58
+
59
+ def make_cosine_decay_schedule(
60
+ init_value: float,
61
+ total_steps: int,
62
+ alpha: float = 0.0,
63
+ exponent: float = 1.0,
64
+ warmup_steps=0,
65
+ ):
66
+ def schedule(count):
67
+ if count < warmup_steps:
68
+ # linear up
69
+ return (init_value / warmup_steps) * count
70
+ else:
71
+ # half cosine down
72
+ decay_steps = total_steps - warmup_steps
73
+ count = min(count - warmup_steps, decay_steps)
74
+ cosine_decay = 0.5 * (1 + np.cos(np.pi * count / decay_steps))
75
+ decayed = (1 - alpha) * cosine_decay**exponent + alpha
76
+ return init_value * decayed
77
+
78
+ return schedule
79
+
80
+
81
+ def make_linear_decay_schedule(
82
+ init_value: float, total_steps: int, final_value: float = 0, warmup_steps=0
83
+ ):
84
+ def schedule(count):
85
+ if count < warmup_steps:
86
+ # linear up
87
+ return (init_value / warmup_steps) * count
88
+ else:
89
+ # linear down
90
+ decay_steps = total_steps - warmup_steps
91
+ count = min(count - warmup_steps, decay_steps)
92
+ return init_value - (init_value - final_value) * count / decay_steps
93
+
94
+ return schedule
95
+
96
+
97
+ def clip_norm_(tensor, max_norm):
98
+ norm = tensor.norm()
99
+ if norm > max_norm:
100
+ tensor.mul_(max_norm / norm)
101
+
102
+
103
+ def lr_warmup(step, warmup_steps):
104
+ return min(1.0, step / max(warmup_steps, 1))
105
+
106
+
107
+ def linear_decay_lambda(step, warmup_steps, decay_steps, total_steps):
108
+ if step < warmup_steps:
109
+ min(1.0, step / max(warmup_steps, 1))
110
+ else:
111
+ # linear down
112
+ # decay_steps = total_steps - warmup_steps
113
+ count = min(step - warmup_steps, decay_steps)
114
+ return 1 - count / decay_steps
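The Adam class above is a small functional optimizer: it returns the updated parameters instead of mutating them, and it accepts a callable learning rate, so it composes directly with make_cosine_decay_schedule. A minimal sketch (toy objective and hyperparameters are illustrative):

import torch

params = torch.tensor([3.0, -2.0])
lr_schedule = make_cosine_decay_schedule(init_value=0.1, total_steps=100, warmup_steps=10)
opt = Adam(params, lr=lr_schedule)

for _ in range(100):
    grad = 2 * params                 # gradient of ||params||^2
    params = opt.step(params, grad)   # functional update: a new tensor is returned
print(params)                         # close to zero after the cosine decay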
src/flair/scheduling.py ADDED
@@ -0,0 +1,882 @@
1
+ # Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.utils.torch_utils import randn_tensor
25
+ from diffusers.schedulers.scheduling_utils import (
26
+ KarrasDiffusionSchedulers,
27
+ SchedulerMixin,
28
+ )
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ @dataclass
35
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
36
+ class EulerDiscreteSchedulerOutput(BaseOutput):
37
+ """
38
+ Output class for the scheduler's `step` function output.
39
+
40
+ Args:
41
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
42
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
43
+ denoising loop.
44
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
45
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
46
+ `pred_original_sample` can be used to preview progress or for guidance.
47
+ """
48
+
49
+ prev_sample: torch.FloatTensor
50
+ pred_original_sample: Optional[torch.FloatTensor] = None
51
+
52
+
53
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
54
+ def betas_for_alpha_bar(
55
+ num_diffusion_timesteps,
56
+ max_beta=0.999,
57
+ alpha_transform_type="cosine",
58
+ ):
59
+ """
60
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
61
+ (1-beta) over time from t = [0,1].
62
+
63
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
64
+ to that part of the diffusion process.
65
+
66
+
67
+ Args:
68
+ num_diffusion_timesteps (`int`): the number of betas to produce.
69
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
70
+ prevent singularities.
71
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
72
+ Choose from `cosine` or `exp`
73
+
74
+ Returns:
75
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
76
+ """
77
+ if alpha_transform_type == "cosine":
78
+
79
+ def alpha_bar_fn(t):
80
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
81
+
82
+ elif alpha_transform_type == "exp":
83
+
84
+ def alpha_bar_fn(t):
85
+ return math.exp(t * -12.0)
86
+
87
+ else:
88
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
89
+
90
+ betas = []
91
+ for i in range(num_diffusion_timesteps):
92
+ t1 = i / num_diffusion_timesteps
93
+ t2 = (i + 1) / num_diffusion_timesteps
94
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
95
+ return torch.tensor(betas, dtype=torch.float32)
96
+
97
+
98
+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
99
+ def rescale_zero_terminal_snr(betas):
100
+ """
101
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
102
+
103
+
104
+ Args:
105
+ betas (`torch.FloatTensor`):
106
+ the betas that the scheduler is being initialized with.
107
+
108
+ Returns:
109
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
110
+ """
111
+ # Convert betas to alphas_bar_sqrt
112
+ alphas = 1.0 - betas
113
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
114
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
115
+
116
+ # Store old values.
117
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
118
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
119
+
120
+ # Shift so the last timestep is zero.
121
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
122
+
123
+ # Scale so the first timestep is back to the old value.
124
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
125
+
126
+ # Convert alphas_bar_sqrt to betas
127
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
128
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
129
+ alphas = torch.cat([alphas_bar[0:1], alphas])
130
+ betas = 1 - alphas
131
+
132
+ return betas
133
+
134
+
135
+ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
136
+ """
137
+ Euler scheduler.
138
+
139
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
140
+ methods the library implements for all schedulers such as loading and saving.
141
+
142
+ Args:
143
+ num_train_timesteps (`int`, defaults to 1000):
144
+ The number of diffusion steps to train the model.
145
+ beta_start (`float`, defaults to 0.0001):
146
+ The starting `beta` value of inference.
147
+ beta_end (`float`, defaults to 0.02):
148
+ The final `beta` value.
149
+ beta_schedule (`str`, defaults to `"linear"`):
150
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
151
+ `linear` or `scaled_linear`.
152
+ trained_betas (`np.ndarray`, *optional*):
153
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
154
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
155
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
156
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
157
+ Video](https://imagen.research.google/video/paper.pdf) paper).
158
+ interpolation_type(`str`, defaults to `"linear"`, *optional*):
159
+ The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
160
+ `"linear"` or `"log_linear"`.
161
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
162
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
163
+ the sigmas are determined according to a sequence of noise levels {σi}.
164
+ timestep_spacing (`str`, defaults to `"linspace"`):
165
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
166
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
167
+ steps_offset (`int`, defaults to 0):
168
+ An offset added to the inference steps. You can use a combination of `offset=1` and
169
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
170
+ Diffusion.
171
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
172
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
173
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
174
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
175
+ """
176
+
177
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
178
+ order = 1
179
+
180
+ @register_to_config
181
+ def __init__(
182
+ self,
183
+ num_train_timesteps: int = 1000,
184
+ beta_start: float = 0.0001,
185
+ beta_end: float = 0.02,
186
+ beta_schedule: str = "linear",
187
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
188
+ prediction_type: str = "epsilon",
189
+ interpolation_type: str = "linear",
190
+ use_karras_sigmas: Optional[bool] = False,
191
+ sigma_min: Optional[float] = None,
192
+ sigma_max: Optional[float] = None,
193
+ timestep_spacing: str = "linspace",
194
+ timestep_type: str = "discrete", # can be "discrete" or "continuous"
195
+ steps_offset: int = 0,
196
+ rescale_betas_zero_snr: bool = False,
197
+ ):
198
+ if trained_betas is not None:
199
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
200
+ elif beta_schedule == "linear":
201
+ self.betas = torch.linspace(
202
+ beta_start, beta_end, num_train_timesteps, dtype=torch.float32
203
+ )
204
+ elif beta_schedule == "scaled_linear":
205
+ # this schedule is very specific to the latent diffusion model.
206
+ self.betas = (
207
+ torch.linspace(
208
+ beta_start**0.5,
209
+ beta_end**0.5,
210
+ num_train_timesteps,
211
+ dtype=torch.float32,
212
+ )
213
+ ** 2
214
+ )
215
+ elif beta_schedule == "squaredcos_cap_v2":
216
+ # Glide cosine schedule
217
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
218
+ else:
219
+ raise NotImplementedError(
220
+ f"{beta_schedule} is not implemented for {self.__class__}"
221
+ )
222
+
223
+ if rescale_betas_zero_snr:
224
+ self.betas = rescale_zero_terminal_snr(self.betas)
225
+
226
+ self.alphas = 1.0 - self.betas
227
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
228
+
229
+ if rescale_betas_zero_snr:
230
+ # Close to 0 without being 0 so first sigma is not inf
231
+ # FP16 smallest positive subnormal works well here
232
+ self.alphas_cumprod[-1] = 2**-24
233
+
234
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
235
+ timesteps = np.linspace(
236
+ 0, num_train_timesteps - 1, num_train_timesteps, dtype=float
237
+ )[::-1].copy()
238
+
239
+ sigmas = torch.from_numpy(sigmas[::-1].copy()).to(dtype=torch.float32)
240
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
241
+
242
+ # setable values
243
+ self.num_inference_steps = None
244
+
245
+ # TODO: Support the full EDM scalings for all prediction types and timestep types
246
+ if timestep_type == "continuous" and prediction_type == "v_prediction":
247
+ self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
248
+ else:
249
+ self.timesteps = timesteps
250
+
251
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
252
+
253
+ self.is_scale_input_called = False
254
+ self.use_karras_sigmas = use_karras_sigmas
255
+
256
+ self._step_index = None
257
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
258
+ self.sigmas_for_inference = False
259
+
260
+ @property
261
+ def init_noise_sigma(self):
262
+ # standard deviation of the initial noise distribution
263
+ max_sigma = (
264
+ max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max()
265
+ )
266
+ if self.config.timestep_spacing in ["linspace", "trailing"]:
267
+ return max_sigma
268
+
269
+ return (max_sigma**2 + 1) ** 0.5
270
+
271
+ @property
272
+ def step_index(self):
273
+ """
274
+ The index counter for the current timestep. It increases by 1 after each scheduler step.
275
+ """
276
+ return self._step_index
277
+
278
+ def scale_model_input(
279
+ self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
280
+ ) -> torch.FloatTensor:
281
+ """
282
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
283
+ current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
284
+
285
+ Args:
286
+ sample (`torch.FloatTensor`):
287
+ The input sample.
288
+ timestep (`int`, *optional*):
289
+ The current timestep in the diffusion chain.
290
+
291
+ Returns:
292
+ `torch.FloatTensor`:
293
+ A scaled input sample.
294
+ """
295
+ if self.step_index is None:
296
+ self._init_step_index(timestep)
297
+
298
+ sigma = self.sigmas[self.step_index]
299
+ sample = sample / ((sigma**2 + 1) ** 0.5)
300
+
301
+ self.is_scale_input_called = True
302
+ return sample
303
+
304
+ def set_timesteps(
305
+ self,
306
+ num_inference_steps: int,
307
+ timesteps=None,
308
+ device: Union[str, torch.device] = None,
309
+ ):
310
+ """
311
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
312
+
313
+ Args:
314
+ num_inference_steps (`int`):
315
+ The number of diffusion steps used when generating samples with a pre-trained model.
316
+ device (`str` or `torch.device`, *optional*):
317
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
318
+ """
319
+ self.num_inference_steps = num_inference_steps
320
+
321
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
322
+ if self.config.timestep_spacing == "linspace":
323
+ timesteps = np.linspace(
324
+ 0,
325
+ self.config.num_train_timesteps - 1,
326
+ num_inference_steps,
327
+ dtype=np.float32,
328
+ )[::-1].copy()
329
+ elif self.config.timestep_spacing == "leading":
330
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
331
+ # creates integer timesteps by multiplying by ratio
332
+ # casting to int to avoid issues when num_inference_step is power of 3
333
+ timesteps = (
334
+ (np.arange(0, num_inference_steps) * step_ratio)
335
+ .round()[::-1]
336
+ .copy()
337
+ .astype(np.float32)
338
+ )
339
+ timesteps += self.config.steps_offset
340
+ elif self.config.timestep_spacing == "trailing":
341
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
342
+ # creates integer timesteps by multiplying by ratio
343
+ # casting to int to avoid issues when num_inference_step is power of 3
344
+ timesteps = (
345
+ (np.arange(self.config.num_train_timesteps, 0, -step_ratio))
346
+ .round()
347
+ .copy()
348
+ .astype(np.float32)
349
+ )
350
+ timesteps -= 1
351
+ else:
352
+ raise ValueError(
353
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
354
+ )
355
+
356
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
357
+ log_sigmas = np.log(sigmas)
358
+
359
+ if self.config.interpolation_type == "linear":
360
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
361
+ elif self.config.interpolation_type == "log_linear":
362
+ sigmas = (
363
+ torch.linspace(
364
+ np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1
365
+ )
366
+ .exp()
367
+ .numpy()
368
+ )
369
+ else:
370
+ raise ValueError(
371
+ f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
372
+ " 'linear' or 'log_linear'"
373
+ )
374
+
375
+ if self.use_karras_sigmas:
376
+ sigmas = self._convert_to_karras(
377
+ in_sigmas=sigmas, num_inference_steps=self.num_inference_steps
378
+ )
379
+ timesteps = np.array(
380
+ [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]
381
+ )
382
+
383
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
384
+
385
+ # TODO: Support the full EDM scalings for all prediction types and timestep types
386
+ if (
387
+ self.config.timestep_type == "continuous"
388
+ and self.config.prediction_type == "v_prediction"
389
+ ):
390
+ self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(
391
+ device=device
392
+ )
393
+ else:
394
+ self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(
395
+ device=device
396
+ )
397
+
398
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
399
+ self._step_index = None
400
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
401
+ self.sigmas_for_inference = True
402
+
403
+ def _sigma_to_t(self, sigma, log_sigmas):
404
+ # get log sigma
405
+ log_sigma = np.log(np.maximum(sigma, 1e-10))
406
+
407
+ # get distribution
408
+ dists = log_sigma - log_sigmas[:, np.newaxis]
409
+
410
+ # get sigmas range
411
+ low_idx = (
412
+ np.cumsum((dists >= 0), axis=0)
413
+ .argmax(axis=0)
414
+ .clip(max=log_sigmas.shape[0] - 2)
415
+ )
416
+ high_idx = low_idx + 1
417
+
418
+ low = log_sigmas[low_idx]
419
+ high = log_sigmas[high_idx]
420
+
421
+ # interpolate sigmas
422
+ w = (low - log_sigma) / (low - high)
423
+ w = np.clip(w, 0, 1)
424
+
425
+ # transform interpolation to time range
426
+ t = (1 - w) * low_idx + w * high_idx
427
+ t = t.reshape(sigma.shape)
428
+ return t
429
+
430
+ # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
431
+ def _convert_to_karras(
432
+ self, in_sigmas: torch.FloatTensor, num_inference_steps
433
+ ) -> torch.FloatTensor:
434
+ """Constructs the noise schedule of Karras et al. (2022)."""
435
+
436
+ # Hack to make sure that other schedulers which copy this function don't break
437
+ # TODO: Add this logic to the other schedulers
438
+ if hasattr(self.config, "sigma_min"):
439
+ sigma_min = self.config.sigma_min
440
+ else:
441
+ sigma_min = None
442
+
443
+ if hasattr(self.config, "sigma_max"):
444
+ sigma_max = self.config.sigma_max
445
+ else:
446
+ sigma_max = None
447
+
448
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
449
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
450
+
451
+ rho = 7.0 # 7.0 is the value used in the paper
452
+ ramp = np.linspace(0, 1, num_inference_steps)
453
+ min_inv_rho = sigma_min ** (1 / rho)
454
+ max_inv_rho = sigma_max ** (1 / rho)
455
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
456
+ return sigmas
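For reference, a standalone numeric sketch (not part of the diff) of the Karras et al. (2022) sigma spacing computed above, sigma_i = (sigma_max^(1/rho) + ramp_i * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho with rho = 7; the sigma_min / sigma_max values are illustrative.

import numpy as np

rho, sigma_min, sigma_max, n = 7.0, 0.03, 14.6, 5
ramp = np.linspace(0, 1, n)
sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
print(sigmas)  # monotonically decreasing from sigma_max to sigma_min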
457
+
458
+ def derivative_sigma_t(self, step):
459
+ rho = 7.0
460
+ sigma_max = self.sigmas[0]
461
+ sigma_min = self.sigmas[-1]
462
+ min_inv_rho = sigma_min ** (1 / rho)
463
+ max_inv_rho = sigma_max ** (1 / rho)
464
+ derivative_sigma_t = (
465
+ -7
466
+ * (max_inv_rho + (step) / 1000 * (min_inv_rho - max_inv_rho)) ** 6
467
+ * (min_inv_rho - max_inv_rho)
468
+ )
469
+ return derivative_sigma_t
470
+
471
+ def _init_step_index(self, timestep):
472
+ if isinstance(timestep, torch.Tensor):
473
+ timestep = timestep.to(self.timesteps.device)
474
+
475
+ index_candidates = (self.timesteps == timestep).nonzero()
476
+
477
+ # The sigma index that is taken for the **very** first `step`
478
+ # is always the second index (or the last index if there is only 1)
479
+ # This way we can ensure we don't accidentally skip a sigma in
480
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
481
+ if len(index_candidates) > 1:
482
+ step_index = index_candidates[1]
483
+ else:
484
+ step_index = index_candidates[0]
485
+
486
+ self._step_index = step_index.item()
487
+
488
+ def step(
489
+ self,
490
+ model_output: torch.FloatTensor,
491
+ timestep: Union[float, torch.FloatTensor],
492
+ sample: torch.FloatTensor,
493
+ s_churn: float = 0.0,
494
+ s_tmin: float = 0.0,
495
+ s_tmax: float = float("inf"),
496
+ s_noise: float = 1.0,
497
+ generator: Optional[torch.Generator] = None,
498
+ return_dict: bool = True,
499
+ likelihood_grad=None,
500
+ likelihood_weight=None,
501
+ regularizer_weight=None,
502
+ step_index: Optional[int] = None,
503
+ ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
504
+ """
505
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
506
+ process from the learned model outputs (most often the predicted noise).
507
+
508
+ Args:
509
+ model_output (`torch.FloatTensor`):
510
+ The direct output from learned diffusion model.
511
+ timestep (`float`):
512
+ The current discrete timestep in the diffusion chain.
513
+ sample (`torch.FloatTensor`):
514
+ A current instance of a sample created by the diffusion process.
515
+ s_churn (`float`):
516
+ s_tmin (`float`):
517
+ s_tmax (`float`):
518
+ s_noise (`float`, defaults to 1.0):
519
+ Scaling factor for noise added to the sample.
520
+ generator (`torch.Generator`, *optional*):
521
+ A random number generator.
522
+ return_dict (`bool`):
523
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
524
+ tuple.
525
+
526
+ Returns:
527
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
528
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
529
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
530
+ """
531
+
532
+ if (
533
+ isinstance(timestep, int)
534
+ or isinstance(timestep, torch.IntTensor)
535
+ or isinstance(timestep, torch.LongTensor)
536
+ ):
537
+ raise ValueError(
538
+ (
539
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
540
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
541
+ " one of the `scheduler.timesteps` as a timestep."
542
+ ),
543
+ )
544
+
545
+ if not self.is_scale_input_called:
546
+ logger.warning(
547
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
548
+ "See `StableDiffusionPipeline` for a usage example."
549
+ )
550
+
551
+ # if self.step_index is None:
552
+ # self._init_step_index(timestep)
553
+ if step_index is None:
554
+ if self.step_index is None:
555
+ self._init_step_index(timestep)
556
+ step_index = torch.tensor(self.step_index).unsqueeze(0)
557
+
558
+ # Upcast to avoid precision issues when computing prev_sample
559
+ sample = sample.to(torch.float32)
560
+ if likelihood_grad is not None:
561
+ likelihood_grad = likelihood_grad.to(torch.float32)
562
+
563
+ sigma = self.sigmas[step_index][:, None, None, None].to(model_output.device)
564
+
565
+ gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
566
+ condition = torch.logical_and(s_tmin <= sigma, sigma <= s_tmax)
567
+ gamma = torch.where(condition, gamma, torch.tensor(0))
568
+
569
+ noise = randn_tensor(
570
+ model_output.shape,
571
+ dtype=model_output.dtype,
572
+ device=model_output.device,
573
+ generator=generator,
574
+ )
575
+
576
+ eps = noise * s_noise
577
+ sigma_hat = sigma * (gamma + 1)
578
+
579
+ # if gamma > 0:
580
+ sample = torch.where(
581
+ gamma > 0, sample + eps * (sigma_hat**2 - sigma**2) ** 0.5, sample
582
+ )
583
+ # sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
584
+
585
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
586
+ # NOTE: "original_sample" should not be an expected prediction_type but is left in for
587
+ # backwards compatibility
588
+ if (
589
+ self.config.prediction_type == "original_sample"
590
+ or self.config.prediction_type == "sample"
591
+ ):
592
+ pred_original_sample = model_output
593
+ elif self.config.prediction_type == "epsilon":
594
+ pred_original_sample = sample - sigma_hat * model_output
595
+ elif self.config.prediction_type == "v_prediction":
596
+ # denoised = model_output * c_out + input * c_skip
597
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
598
+ sample / (sigma**2 + 1)
599
+ )
600
+ else:
601
+ raise ValueError(
602
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
603
+ )
604
+
605
+ # 2. Convert to an ODE derivative
606
+ derivative = (sample - pred_original_sample) / sigma_hat
607
+ if likelihood_grad is not None:
608
+ likelihood_grad = likelihood_grad * (
609
+ torch.norm(derivative) / torch.norm(likelihood_grad)
610
+ )
611
+ derivative = (
612
+ regularizer_weight * derivative + likelihood_weight * likelihood_grad
613
+ )
614
+
615
+ dt = (
616
+ self.sigmas[step_index + 1].to(sigma_hat.device)[:, None, None, None]
617
+ - sigma_hat
618
+ )
619
+
620
+ prev_sample = sample + derivative * dt
621
+
622
+ # Cast sample back to model compatible dtype
623
+ prev_sample = prev_sample.to(model_output.dtype)
624
+
625
+ # upon completion increase step index by one
626
+ if self._step_index is not None:
627
+ self._step_index += 1
628
+
629
+ if not return_dict:
630
+ return (prev_sample,)
631
+
632
+ return EulerDiscreteSchedulerOutput(
633
+ prev_sample=prev_sample, pred_original_sample=pred_original_sample
634
+ )
635
+
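In compact form, the guided update implemented by `step` is x_next = x + (w_reg * (x - x0_hat) / sigma_hat + w_lik * g_hat) * (sigma_next - sigma_hat), where g_hat is the likelihood gradient rescaled to the norm of the ODE derivative. A minimal sketch on dummy tensors (not wired to the scheduler's state; all values are placeholders):

import torch

x = torch.randn(1, 4, 8, 8)            # current latent
x0_hat = torch.randn_like(x)           # denoised estimate from the model
lik_grad = torch.randn_like(x)         # data-consistency gradient (assumed given)
sigma_hat, sigma_next = 3.0, 2.5
w_reg, w_lik = 1.0, 0.5                # illustrative weights

deriv = (x - x0_hat) / sigma_hat
g_hat = lik_grad * (deriv.norm() / lik_grad.norm())   # same rescaling as in `step`
x_next = x + (w_reg * deriv + w_lik * g_hat) * (sigma_next - sigma_hat)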
636
+ def add_noise(
637
+ self,
638
+ original_samples: torch.FloatTensor,
639
+ noise: torch.FloatTensor,
640
+ timesteps: torch.FloatTensor,
641
+ ) -> torch.FloatTensor:
642
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
643
+ sigmas = self.sigmas.to(
644
+ device=original_samples.device, dtype=original_samples.dtype
645
+ )
646
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
647
+ # mps does not support float64
648
+ schedule_timesteps = self.timesteps.to(
649
+ original_samples.device, dtype=torch.float32
650
+ )
651
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
652
+ else:
653
+ schedule_timesteps = self.timesteps.to(original_samples.device)
654
+ timesteps = timesteps.to(original_samples.device)
655
+
656
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
657
+
658
+ sigma = sigmas[step_indices].flatten()
659
+ while len(sigma.shape) < len(original_samples.shape):
660
+ sigma = sigma.unsqueeze(-1)
661
+
662
+ noisy_samples = original_samples + noise * sigma
663
+ return noisy_samples
664
+
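`add_noise` follows the convention x_sigma = x_0 + sigma * eps (the clean sample is not rescaled). A minimal sketch, assuming sigma has already been looked up from the schedule:

import torch

x0 = torch.randn(1, 4, 8, 8)
eps = torch.randn_like(x0)
sigma = 2.0                     # illustrative value; normally taken from self.sigmas
x_noisy = x0 + sigma * eps      # same convention as add_noise above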
665
+ def __len__(self):
666
+ return self.config.num_train_timesteps
667
+
668
+
669
+ class HeunDiscreteScheduler(EulerDiscreteScheduler):
670
+
671
+ def set_timesteps(
672
+ self,
673
+ num_inference_steps: int,
674
+ timesteps=None,
675
+ device: Union[str, torch.device] = None,
676
+ ):
677
+ """
678
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
679
+
680
+ Args:
681
+ num_inference_steps (`int`):
682
+ The number of diffusion steps used when generating samples with a pre-trained model.
683
+ device (`str` or `torch.device`, *optional*):
684
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
685
+ """
686
+ self.num_inference_steps = num_inference_steps
687
+
688
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
689
+ if self.config.timestep_spacing == "linspace":
690
+ timesteps = np.linspace(
691
+ 0,
692
+ self.config.num_train_timesteps - 1,
693
+ num_inference_steps,
694
+ dtype=np.float32,
695
+ )[::-1].copy()
696
+ elif self.config.timestep_spacing == "leading":
697
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
698
+ # creates integer timesteps by multiplying by ratio
699
+ # casting to int to avoid issues when num_inference_step is power of 3
700
+ timesteps = (
701
+ (np.arange(0, num_inference_steps) * step_ratio)
702
+ .round()[::-1]
703
+ .copy()
704
+ .astype(np.float32)
705
+ )
706
+ timesteps += self.config.steps_offset
707
+ elif self.config.timestep_spacing == "trailing":
708
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
709
+ # creates integer timesteps by multiplying by ratio
710
+ # casting to int to avoid issues when num_inference_step is power of 3
711
+ timesteps = (
712
+ (np.arange(self.config.num_train_timesteps, 0, -step_ratio))
713
+ .round()
714
+ .copy()
715
+ .astype(np.float32)
716
+ )
717
+ timesteps -= 1
718
+ else:
719
+ raise ValueError(
720
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
721
+ )
722
+
723
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
724
+ log_sigmas = np.log(sigmas)
725
+
726
+ if self.config.interpolation_type == "linear":
727
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
728
+ elif self.config.interpolation_type == "log_linear":
729
+ sigmas = (
730
+ torch.linspace(
731
+ np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1
732
+ )
733
+ .exp()
734
+ .numpy()
735
+ )
736
+ else:
737
+ raise ValueError(
738
+ f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
739
+ " 'linear' or 'log_linear'"
740
+ )
741
+
742
+ if self.use_karras_sigmas:
743
+ sigmas = self._convert_to_karras(
744
+ in_sigmas=sigmas, num_inference_steps=self.num_inference_steps
745
+ )
746
+ timesteps = np.array(
747
+ [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]
748
+ )
749
+
750
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
751
+
752
+ sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2)])
753
+
754
+ # TODO: Support the full EDM scalings for all prediction types and timestep types
755
+ if (
756
+ self.config.timestep_type == "continuous"
757
+ and self.config.prediction_type == "v_prediction"
758
+ ):
759
+ self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(
760
+ device=device
761
+ )
762
+ else:
763
+ self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(
764
+ device=device
765
+ )
766
+
767
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
768
+ self._step_index = None
769
+ self.sigmas = self.sigmas.to("cpu")
770
+
771
+ # empty dt and derivative
772
+ self.prev_derivative = None
773
+ self.dt = None
774
+
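The `repeat_interleave(2)` duplicates every sigma after the first so that each Heun update visits its sigma pair twice (predictor pass, then corrector pass). On a toy tensor:

import torch

s = torch.tensor([4.0, 3.0, 2.0, 1.0])
print(torch.cat([s[:1], s[1:].repeat_interleave(2)]))
# tensor([4., 3., 3., 2., 2., 1., 1.])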
775
+ def step(
776
+ self,
777
+ model_output: Union[torch.Tensor, np.ndarray],
778
+ timestep: Union[float, torch.Tensor],
779
+ sample: Union[torch.Tensor, np.ndarray],
780
+ return_dict: bool = True,
781
+ ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
782
+ """
783
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
784
+ process from the learned model outputs (most often the predicted noise).
785
+
786
+ Args:
787
+ model_output (`torch.Tensor`):
788
+ The direct output from learned diffusion model.
789
+ timestep (`float`):
790
+ The current discrete timestep in the diffusion chain.
791
+ sample (`torch.Tensor`):
792
+ A current instance of a sample created by the diffusion process.
793
+ return_dict (`bool`):
794
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
795
+
796
+ Returns:
797
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
798
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
799
+ tuple is returned where the first element is the sample tensor.
800
+ """
801
+ if self.step_index is None:
802
+ self._init_step_index(timestep)
803
+
804
+ if self.state_in_first_order:
805
+ sigma = self.sigmas[self.step_index]
806
+ sigma_next = self.sigmas[self.step_index + 1]
807
+ else:
808
+ # 2nd order / Heun's method
809
+ sigma = self.sigmas[self.step_index - 1]
810
+ sigma_next = self.sigmas[self.step_index]
811
+
812
+ # currently only gamma=0 is supported. This usually works best anyways.
813
+ # We can support gamma in the future but then need to scale the timestep before
814
+ # passing it to the model which requires a change in API
815
+ gamma = 0
816
+ sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
817
+
818
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
819
+ if self.config.prediction_type == "epsilon":
820
+ sigma_input = sigma_hat if self.state_in_first_order else sigma_next
821
+ pred_original_sample = sample - sigma_input * model_output
822
+ elif self.config.prediction_type == "v_prediction":
823
+ sigma_input = sigma_hat if self.state_in_first_order else sigma_next
824
+ pred_original_sample = model_output * (
825
+ -sigma_input / (sigma_input**2 + 1) ** 0.5
826
+ ) + (sample / (sigma_input**2 + 1))
827
+ elif self.config.prediction_type == "sample":
828
+ pred_original_sample = model_output
829
+ else:
830
+ raise ValueError(
831
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
832
+ )
833
+
834
+ if self.config.clip_sample:
835
+ pred_original_sample = pred_original_sample.clamp(
836
+ -self.config.clip_sample_range, self.config.clip_sample_range
837
+ )
838
+
839
+ if self.state_in_first_order:
840
+ # 2. Convert to an ODE derivative for 1st order
841
+ derivative = (sample - pred_original_sample) / sigma_hat
842
+ # 3. delta timestep
843
+ dt = sigma_next - sigma_hat
844
+
845
+ # store for 2nd order step
846
+ self.prev_derivative = derivative
847
+ self.dt = dt
848
+ self.sample = sample
849
+ else:
850
+ # 2. 2nd order / Heun's method
851
+ derivative = (sample - pred_original_sample) / sigma_next
852
+ derivative = (self.prev_derivative + derivative) / 2
853
+
854
+ # 3. take prev timestep & sample
855
+ dt = self.dt
856
+ sample = self.sample
857
+
858
+ # free dt and derivative
859
+ # Note, this puts the scheduler in "first order mode"
860
+ self.prev_derivative = None
861
+ self.dt = None
862
+ self.sample = None
863
+
864
+ prev_sample = sample + derivative * dt
865
+
866
+ # upon completion increase step index by one
867
+ self._step_index += 1
868
+
869
+ if not return_dict:
870
+ return (prev_sample,)
871
+
872
+ return EulerDiscreteSchedulerOutput(
873
+ prev_sample=prev_sample, pred_original_sample=pred_original_sample
874
+ )
875
+
876
+ @property
877
+ def state_in_first_order(self):
878
+ return self.dt is None
879
+
880
+ @property
881
+ def order(self):
882
+ return 2
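A compact, stateless restatement of the two-stage Heun update above, assuming an epsilon-prediction denoiser (for epsilon prediction the ODE derivative (sample - pred_x0) / sigma reduces to the predicted noise). The names below are placeholders, not part of the scheduler API:

import torch

def heun_step(x, eps_model, sigma, sigma_next):
    d1 = eps_model(x, sigma)                        # slope at sigma
    x_euler = x + d1 * (sigma_next - sigma)         # first-order (Euler) proposal
    d2 = eps_model(x_euler, sigma_next)             # slope re-evaluated at sigma_next
    return x + 0.5 * (d1 + d2) * (sigma_next - sigma)

x = torch.randn(1, 4, 8, 8)
fake_eps = lambda z, s: torch.zeros_like(z)         # stand-in for a real denoiser
x_next = heun_step(x, fake_eps, sigma=3.0, sigma_next=2.0)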
src/flair/utils.py ADDED
@@ -0,0 +1,406 @@
1
+ import random
2
+ from pathlib import Path
3
+ import os
4
+ import torch
5
+ from torchmetrics.functional.image import (
6
+ peak_signal_noise_ratio,
7
+ learned_perceptual_image_patch_similarity,
8
+ )
9
+ from PIL import Image
10
+ from skimage.color import rgb2lab, lab2rgb
11
+ import numpy as np
12
+ import cv2
13
+ from torchvision import transforms
14
+
15
+
16
+ RESAMPLE_MODE = Image.BICUBIC
17
+
18
+
19
+ def skip_iterator(iterator, skip):
20
+ for i, item in enumerate(iterator):
21
+ if i % skip == 0:
22
+ yield item
23
+
24
+
25
+ def generate_output_structure(output_dir, subfolders=[]):
26
+ """
27
+ Generate a directory structure for the output and return filename templates for the subfolders.
28
+ """
29
+ output_dir = Path(output_dir)
30
+ output_dir.mkdir(exist_ok=True, parents=True)
31
+ output_paths = []
32
+ for subfolder in subfolders:
33
+ output_paths.append(Path(os.path.join(output_dir, subfolder)))
34
+ (output_paths[-1]).mkdir(exist_ok=True, parents=True)
35
+ output_paths[-1] = os.path.join(output_paths[-1], "{}.png")
36
+ return output_paths
37
+
38
+
39
+ def find_files(path, ext="png"):
40
+ if os.path.isdir(path):
41
+ path = Path(path)
42
+ sorted_files = sorted(list(path.glob(f"*.{ext}")))
43
+ return sorted_files
44
+ else:
45
+ return [path]
46
+
47
+
48
+ def load_guidance_image(path, size=None):
49
+ """
50
+ Load an image and convert it to a tensor.
51
+ Args: path to the image
52
+ returns: tensor of the image of shape (1, 3, H, W), scaled to [-1, 1]
53
+ """
54
+ img = Image.open(path)
55
+ img = img.convert("RGB")
56
+ tf = transforms.Compose([
57
+ transforms.Resize(size),
58
+ transforms.CenterCrop(size),
59
+ transforms.ToTensor()
60
+ ])
61
+ img = tf(img) * 2 - 1
62
+ return img.unsqueeze(0)
63
+
64
+
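`load_guidance_image` resizes the shorter side, center-crops to a square, and rescales to [-1, 1] (ToTensor gives [0, 1], then * 2 - 1). A hypothetical usage sketch (the path is a placeholder):

img = load_guidance_image("path/to/image.png", size=512)
print(img.shape)                            # torch.Size([1, 3, 512, 512])
print(img.min().item(), img.max().item())   # values lie in [-1, 1]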
65
+ def yield_images(path, ext="png", size=None):
66
+ files = find_files(path, ext)
67
+ for file in files:
68
+ yield load_guidance_image(file, size)
69
+
70
+
71
+ def yield_videos(paths, ext="png", H=None, W=None, n_frames=61):
72
+ for path in paths:
73
+ yield read_video(path, H, W, n_frames)
74
+
75
+
76
+ def read_video(path, H=None, W=None, n_frames=61) -> tuple[torch.Tensor, tuple]:
77
+ path = Path(path)
78
+ frames = []
79
+
80
+ if Path(path).is_dir():
81
+ files = sorted(list(path.glob("*.png")))
82
+ for file in files[:n_frames]:
83
+ image = Image.open(file)
84
+ image.load()
85
+ if H is not None and W is not None:
86
+ image = image.resize((W, H), resample=Image.BICUBIC)
87
+ # to tensor
88
+ image = (
89
+ torch.tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1)
90
+ / 255.0
91
+ * 2
92
+ - 1
93
+ )
94
+ frames.append(image)
95
+ H, W = frames[0].size()[-2:]
96
+ frames = torch.stack(frames).unsqueeze(0)
97
+ return frames, (10, H, W)
98
+
99
+ capture = cv2.VideoCapture(str(path))
100
+ fps = capture.get(cv2.CAP_PROP_FPS)
101
+
102
+ while True:
103
+ success, frame = capture.read()
104
+ if not success:
105
+ break
106
+
107
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
108
+ frame = Image.fromarray(frame)
109
+ if H is not None and W is not None:
110
+ frame = frame.resize((W, H), resample=Image.BICUBIC)
111
+ # to tensor
112
+ frame = (
113
+ torch.tensor(np.array(frame), dtype=torch.float32).permute(2, 0, 1)
114
+ / 255.0
115
+ * 2
116
+ - 1
117
+ )
118
+ frames.append(frame)
119
+
120
+ capture.release()
121
+ # to torch
122
+ frames = torch.stack(frames).unsqueeze(0)
123
+ return frames, (fps, W, H)
124
+
125
+
126
+ def resize_video(
127
+ video: list[Image], width, height, resample_mode=RESAMPLE_MODE
128
+ ) -> list[Image]:
129
+ frames_lr = []
130
+ for frame in video:
131
+ frame_lr = frame.resize((width, height), resample_mode)
132
+ frames_lr.append(frame_lr)
133
+ return frames_lr
134
+
135
+
136
+ def export_to_video(
137
+ video_frames,
138
+ output_video_path=None,
139
+ fps=8,
140
+ put_numbers=False,
141
+ annotations=None,
142
+ fourcc="mp4v",
143
+ ):
144
+ fourcc = cv2.VideoWriter_fourcc(*fourcc) # codec
145
+ writer = cv2.VideoWriter(output_video_path, fourcc, fps, video_frames[0].size)
146
+ for i, frame in enumerate(video_frames):
147
+ frame = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)
148
+
149
+ if put_numbers:
150
+ text_position = (frame.shape[1] - 60, 30)
151
+ font = cv2.FONT_HERSHEY_SIMPLEX
152
+ font_scale = 1
153
+ font_color = (255, 255, 255)
154
+ line_type = 2
155
+ cv2.putText(
156
+ frame,
157
+ f"{i + 1}",
158
+ text_position,
159
+ font,
160
+ font_scale,
161
+ font_color,
162
+ line_type,
163
+ )
164
+
165
+ if annotations:
166
+ annotation = annotations[i]
167
+ frame = draw_bodypose(
168
+ frame, annotation["candidates"], annotation["subsets"]
169
+ )
170
+
171
+ writer.write(frame)
172
+ writer.release()
173
+
174
+
175
+ def export_images(frames: list[Image], dir_name):
176
+ dir_name = Path(dir_name)
177
+ dir_name.mkdir(exist_ok=True, parents=True)
178
+ for i, frame in enumerate(frames):
179
+ frame.save(dir_name / f"{i:05d}.png")
180
+
181
+
182
+ def vid2tensor(images: list[Image]) -> torch.Tensor:
183
+ # PIL to numpy
184
+ if not isinstance(images, list):
185
+ raise ValueError("vid2tensor expects a list of PIL images")
186
+ images = [np.array(image).astype(np.float32) / 255.0 for image in images]
187
+ images = np.stack(images, axis=0)
188
+
189
+ if images.ndim == 3:
190
+ # L mode, add luminance channel
191
+ images = np.expand_dims(images, -1)
192
+
193
+ # numpy to torch
194
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
195
+ return images
196
+
197
+
198
+ def compute_metrics(
199
+ source: list[Image],
200
+ output: list[Image],
201
+ output_lq: list[Image],
202
+ target: list[Image],
203
+ ) -> dict:
204
+ psnr_ab = torch.tensor(
205
+ np.mean(compute_color_metrics(output, target)["psnr_ab"])
206
+ ).float()
207
+
208
+ source = vid2tensor(source)
209
+ output = vid2tensor(output)
210
+ output_lq = vid2tensor(output_lq)
211
+ target = vid2tensor(target)
212
+
213
+ mse = ((output - target) ** 2).mean()
214
+ psnr = peak_signal_noise_ratio(output, target, data_range=1.0, dim=(1, 2, 3))
215
+ # lpips = learned_perceptual_image_patch_similarity(output, target)
216
+
217
+ mse_source = ((output_lq - source) ** 2).mean()
218
+ psnr_source = peak_signal_noise_ratio(
219
+ output_lq, source, data_range=1.0, dim=(1, 2, 3)
220
+ )
221
+
222
+ return {
223
+ "mse": mse.detach().cpu().item(),
224
+ "psnr": psnr.detach().cpu().item(),
225
+ "psnr_ab": psnr_ab.detach().cpu().item(),
226
+ "mse_source": mse_source.detach().cpu().item(),
227
+ "psnr_source": psnr_source.detach().cpu().item(),
228
+ }
229
+
230
+
231
+ def compute_psnr_ab(x, y_gt, pp_max=202.33542248):
232
+ """Computes the PSNR of the ab color channels.
233
+
234
+ Note that the CIE-Lab space is asymmetric.
235
+ The maximum range of the two ab channels is approximately 202.3354.
236
+ pp_max: approximate maximum swing of the ab channels of the CIE-Lab color space,
237
+ i.e. max_{x in CIE-Lab}(x_a, x_b) - min_{x in CIE-Lab}(x_a, x_b)
238
+ """
239
+ assert (
240
+ len(x.shape) == 3
241
+ ), f"Expecting data of the size HW2 but found {x.shape}; This should be a,b channels of CIE-Lab Space"
242
+ assert (
243
+ len(y_gt.shape) == 3
244
+ ), f"Expecting data of the size HW2 but found {y_gt.shape}; This should be a,b channels of CIE-Lab Space"
245
+ assert (
246
+ x.shape == y_gt.shape
247
+ ), f"Expecting data to have identical shape but found {y_gt.shape} != {x.shape}"
248
+
249
+ H, W, C = x.shape
250
+ assert (
251
+ C == 2
252
+ ), f"This function assumes that both x & y are both the ab channels of the CIE-Lab Space"
253
+
254
+ MSE = np.sum((x - y_gt) ** 2) / (H * W * C) # C=2, two channels
255
+ MSE = np.clip(MSE, 1e-12, np.inf)
256
+
257
+ PSNR_ab = 10 * np.log10(pp_max**2) - 10 * np.log10(MSE)
258
+
259
+ return PSNR_ab
260
+
261
+
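A quick usage sketch of `compute_psnr_ab` on the ab channels of two nearly identical RGB frames (random data, purely illustrative):

import numpy as np
from skimage.color import rgb2lab

rgb_a = np.random.rand(32, 32, 3)
rgb_b = np.clip(rgb_a + 0.01 * np.random.randn(32, 32, 3), 0, 1)
lab_a, lab_b = rgb2lab(rgb_a), rgb2lab(rgb_b)
print(compute_psnr_ab(lab_a[..., 1:3], lab_b[..., 1:3]))   # high PSNR for nearly identical frames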
262
+ def compute_color_metrics(out: list[Image], target: list[Image]):
263
+ if len(out) != len(target):
264
+ raise ValueError("Videos do not have same length")
265
+
266
+ metrics = {"psnr_ab": []}
267
+
268
+ for out_frame, target_frame in zip(out, target):
269
+ out_frame, target_frame = np.asarray(out_frame), np.asarray(target_frame)
270
+ out_frame_lab, target_frame_lab = rgb2lab(out_frame), rgb2lab(target_frame)
271
+
272
+ psnr_ab = compute_psnr_ab(out_frame_lab[..., 1:3], target_frame_lab[..., 1:3])
273
+ metrics["psnr_ab"].append(psnr_ab.item())
274
+
275
+ return metrics
276
+
277
+
278
+ def to_device(sample, device):
279
+ result = {}
280
+ for key, val in sample.items():
281
+ if isinstance(val, torch.Tensor):
282
+ result[key] = val.to(device)
283
+ elif isinstance(val, list):
284
+ new_val = []
285
+ for e in val:
286
+ if isinstance(e, torch.Tensor):
287
+ new_val.append(e.to(device))
288
+ else:
289
+ new_val.append(e)
290
+ result[key] = new_val
291
+ else:
292
+ result[key] = val
293
+ return result
294
+
295
+
296
+ def seed_all(seed=42):
297
+ random.seed(seed)
298
+ np.random.seed(seed)
299
+ torch.manual_seed(seed)
300
+
301
+
302
+ def adapt_unet(unet, lora_rank=None, in_conv_mode="zeros"):
303
+ # adapt conv_in
304
+ kernel = unet.conv_in.weight.data
305
+ if in_conv_mode == "zeros":
306
+ new_kernel = torch.zeros(320, 4, 3, 3, dtype=kernel.dtype, device=kernel.device)
307
+ elif in_conv_mode == "reflect":
308
+ new_kernel = kernel[:, 4:].clone()
309
+ else:
310
+ raise NotImplementedError
311
+ unet.conv_in.weight.data = torch.cat([kernel, new_kernel], dim=1)
312
+ if in_conv_mode == "reflect":
313
+ unet.conv_in.weight.data *= 2.0 / 3.0
314
+ unet.conv_in.in_channels = 12
315
+
316
+ if lora_rank is not None:
317
+ from peft import LoraConfig
318
+
319
+ types = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Linear, torch.nn.Embedding)
320
+ target_modules = [
321
+ (n, m) for n, m in unet.named_modules() if isinstance(m, types)
322
+ ]
323
+ # identify parameters (not modules) that will not be lora'd
324
+ for _, m in target_modules:
325
+ m.requires_grad_(False)
326
+ not_adapted = [p for p in unet.parameters() if p.requires_grad]
327
+
328
+ unet_lora_config = LoraConfig(
329
+ r=lora_rank,
330
+ lora_alpha=lora_rank,
331
+ init_lora_weights="gaussian",
332
+ # target_modules=["to_k", "to_q", "to_v", "to_out.0"],
333
+ target_modules=[n for n, _ in target_modules],
334
+ )
335
+ # the following line sets all parameters except the loras to non-trainable
336
+ unet.add_adapter(unet_lora_config)
337
+
338
+ unet.conv_in.requires_grad_()
339
+ for p in not_adapted:
340
+ p.requires_grad_()
341
+
342
+
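The conv_in widening used above can be illustrated on a plain convolution: concatenating zero-initialized kernels along the input-channel axis leaves the output unchanged as long as the extra channels receive zeros. A toy sketch, not tied to any particular UNet:

import torch
from torch import nn

conv = nn.Conv2d(4, 8, 3, padding=1)
x = torch.randn(1, 4, 16, 16)
y_before = conv(x)

conv.weight.data = torch.cat([conv.weight.data, torch.zeros(8, 4, 3, 3)], dim=1)   # widen 4 -> 8 input channels
conv.in_channels = 8

x_wide = torch.cat([x, torch.zeros(1, 4, 16, 16)], dim=1)   # extra channels fed zeros
assert torch.allclose(conv(x_wide), y_before)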
343
+ def repeat_infinite(iterable):
344
+ def repeated():
345
+ while True:
346
+ yield from iterable
347
+
348
+ return repeated
349
+
350
+
351
+ class CPUAdam:
352
+ def __init__(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
353
+ self.params = list(params)
354
+ self.lr = lr
355
+ self.betas = betas
356
+ self.eps = eps
357
+ self.weight_decay = weight_decay
358
+ # keep this in main memory to save VRAM
359
+ self.state = {
360
+ param: {
361
+ "step": 0,
362
+ "exp_avg": torch.zeros_like(param, device="cpu"),
363
+ "exp_avg_sq": torch.zeros_like(param, device="cpu"),
364
+ }
365
+ for param in self.params
366
+ }
367
+
368
+ def step(self):
369
+ for param in self.params:
370
+ if param.grad is None:
371
+ continue
372
+
373
+ grad = param.grad.data.cpu()
374
+ if self.weight_decay != 0:
375
+ grad.add_(param.data, alpha=self.weight_decay)
376
+
377
+ state = self.state[param]
378
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
379
+ beta1, beta2 = self.betas
380
+
381
+ state["step"] += 1
382
+
383
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
384
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
385
+
386
+ denom = exp_avg_sq.sqrt().add_(self.eps)
387
+
388
+ step_size = (
389
+ self.lr
390
+ * (1 - beta2 ** state["step"]) ** 0.5
391
+ / (1 - beta1 ** state["step"])
392
+ )
393
+
394
+ # param.data.add_((-step_size * (exp_avg / denom)).cuda())
395
+
396
+ param.data.addcdiv_(exp_avg.cuda(), denom.cuda(), value=-step_size)
397
+
398
+ def zero_grad(self):
399
+ for param in self.params:
400
+ param.grad = None
401
+
402
+ def state_dict(self):
403
+ return self.state
404
+
405
+ def set_lr(self, lr):
406
+ self.lr = lr
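A minimal usage sketch of `CPUAdam`; its Adam moments live in CPU RAM while the parameters stay on the GPU. Note that `step` moves the update back with `.cuda()`, so as written it assumes a CUDA device is available:

import torch

param = torch.randn(4, requires_grad=True, device="cuda")
opt = CPUAdam([param], lr=1e-2)

loss = (param ** 2).sum()
loss.backward()
opt.step()          # moments stay on CPU; only the update is moved to the GPU
opt.zero_grad()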
src/flair/utils/__init__.py ADDED
File without changes
src/flair/utils/blur_util.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch
2
+ from torch import nn
3
+ import numpy as np
4
+ import scipy
5
+
6
+ from flair.utils.motionblur import Kernel as MotionKernel
7
+
8
+ class Blurkernel(nn.Module):
9
+ def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None):
10
+ super().__init__()
11
+ self.blur_type = blur_type
12
+ self.kernel_size = kernel_size
13
+ self.std = std
14
+ self.device = device
15
+ self.seq = nn.Sequential(
16
+ nn.ReflectionPad2d(self.kernel_size//2),
17
+ nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3)
18
+ )
19
+
20
+ self.weights_init()
21
+
22
+ def forward(self, x):
23
+ return self.seq(x)
24
+
25
+ def weights_init(self):
26
+ if self.blur_type == "gaussian":
27
+ n = np.zeros((self.kernel_size, self.kernel_size))
28
+ n[self.kernel_size // 2,self.kernel_size // 2] = 1
29
+ k = scipy.ndimage.gaussian_filter(n, sigma=self.std)
30
+ k = torch.from_numpy(k)
31
+ self.k = k
32
+ for name, f in self.named_parameters():
33
+ f.data.copy_(k)
34
+ elif self.blur_type == "motion":
35
+ k = MotionKernel(size=(self.kernel_size, self.kernel_size), intensity=self.std).kernelMatrix
36
+ k = torch.from_numpy(k)
37
+ self.k = k
38
+ for name, f in self.named_parameters():
39
+ f.data.copy_(k)
40
+
41
+ def update_weights(self, k):
42
+ if not torch.is_tensor(k):
43
+ k = torch.from_numpy(k).to(self.device)
44
+ for name, f in self.named_parameters():
45
+ f.data.copy_(k)
46
+
47
+ def get_kernel(self):
48
+ return self.k
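A brief usage sketch of `Blurkernel` in Gaussian mode; the module applies the same depthwise kernel to all three channels with reflection padding, so spatial size is preserved:

import torch

blur = Blurkernel(blur_type="gaussian", kernel_size=31, std=3.0)
img = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    out = blur(img)
print(out.shape)   # torch.Size([1, 3, 256, 256])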