rahul7star committed
Commit e0336bc · verified
1 Parent(s): defedb5

Upload 303 files

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +1 -0
  2. .gitignore +18 -0
  3. .python-version +1 -0
  4. README.ja.md +377 -0
  5. README.md +269 -12
  6. base_hv_generate_video.py +936 -0
  7. base_wan_generate_video.py +1892 -0
  8. blissful_tuner/GIMMVFI.py +208 -0
  9. blissful_tuner/__init__.py +0 -0
  10. blissful_tuner/advanced_rope.py +112 -0
  11. blissful_tuner/blissful_args.py +131 -0
  12. blissful_tuner/blissful_settings.py +111 -0
  13. blissful_tuner/cfgzerostar.py +39 -0
  14. blissful_tuner/codeformer/LICENSE +15 -0
  15. blissful_tuner/codeformer/basicsr/VERSION +1 -0
  16. blissful_tuner/codeformer/basicsr/__init__.py +11 -0
  17. blissful_tuner/codeformer/basicsr/archs/__init__.py +25 -0
  18. blissful_tuner/codeformer/basicsr/archs/arcface_arch.py +245 -0
  19. blissful_tuner/codeformer/basicsr/archs/arch_util.py +318 -0
  20. blissful_tuner/codeformer/basicsr/archs/codeformer_arch.py +280 -0
  21. blissful_tuner/codeformer/basicsr/archs/rrdbnet_arch.py +119 -0
  22. blissful_tuner/codeformer/basicsr/archs/vgg_arch.py +161 -0
  23. blissful_tuner/codeformer/basicsr/archs/vqgan_arch.py +434 -0
  24. blissful_tuner/codeformer/basicsr/data/__init__.py +100 -0
  25. blissful_tuner/codeformer/basicsr/data/data_sampler.py +48 -0
  26. blissful_tuner/codeformer/basicsr/data/data_util.py +392 -0
  27. blissful_tuner/codeformer/basicsr/data/ffhq_blind_dataset.py +299 -0
  28. blissful_tuner/codeformer/basicsr/data/ffhq_blind_joint_dataset.py +324 -0
  29. blissful_tuner/codeformer/basicsr/data/gaussian_kernels.py +690 -0
  30. blissful_tuner/codeformer/basicsr/data/paired_image_dataset.py +101 -0
  31. blissful_tuner/codeformer/basicsr/data/prefetch_dataloader.py +125 -0
  32. blissful_tuner/codeformer/basicsr/data/transforms.py +165 -0
  33. blissful_tuner/codeformer/basicsr/losses/__init__.py +26 -0
  34. blissful_tuner/codeformer/basicsr/losses/loss_util.py +95 -0
  35. blissful_tuner/codeformer/basicsr/losses/losses.py +455 -0
  36. blissful_tuner/codeformer/basicsr/metrics/__init__.py +19 -0
  37. blissful_tuner/codeformer/basicsr/metrics/metric_util.py +45 -0
  38. blissful_tuner/codeformer/basicsr/metrics/psnr_ssim.py +128 -0
  39. blissful_tuner/codeformer/basicsr/models/__init__.py +30 -0
  40. blissful_tuner/codeformer/basicsr/models/base_model.py +322 -0
  41. blissful_tuner/codeformer/basicsr/models/codeformer_idx_model.py +220 -0
  42. blissful_tuner/codeformer/basicsr/models/codeformer_joint_model.py +350 -0
  43. blissful_tuner/codeformer/basicsr/models/codeformer_model.py +332 -0
  44. blissful_tuner/codeformer/basicsr/models/lr_scheduler.py +96 -0
  45. blissful_tuner/codeformer/basicsr/models/sr_model.py +209 -0
  46. blissful_tuner/codeformer/basicsr/models/vqgan_model.py +285 -0
  47. blissful_tuner/codeformer/basicsr/ops/__init__.py +0 -0
  48. blissful_tuner/codeformer/basicsr/ops/dcn/__init__.py +7 -0
  49. blissful_tuner/codeformer/basicsr/ops/dcn/deform_conv.py +377 -0
  50. blissful_tuner/codeformer/basicsr/ops/dcn/src/deform_conv_cuda.cpp +685 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ images/screenshot.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,18 @@
1
+ __pycache__/
2
+ .venv
3
+ venv/
4
+ logs/
5
+ uv.lock
6
+ .env
7
+ env/
8
+ outputs/
9
+ GIMM-VFI/
10
+ hunyuan/
11
+ temp_frames/
12
+ wan/wan2.1_i2v_480p_14B_bf16.safetensors
13
+ wan/wan2.1_t2v_14B_bf16.safetensors
14
+ wan/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
15
+ wan/models_t5_umt5-xxl-enc-bf16.pth
16
+ triton-3.0.0-cp310-cp310-win_amd64.whl
17
+ wan/Wan2.1_VAE.pth
18
+ flash_attn-2.7.4+cu124torch2.5.1cxx11abiFALSE-cp310-cp310-win_amd64.whl
.python-version ADDED
@@ -0,0 +1 @@
1
+ 3.10
README.ja.md ADDED
@@ -0,0 +1,377 @@
1
+ # Musubi Tuner
2
+
3
+ [English](./README.md) | [日本語](./README.ja.md)
4
+
5
+ ## 目次
6
+
7
+ - [はじめに](#はじめに)
8
+ - [最近の更新](#最近の更新)
9
+ - [リリースについて](#リリースについて)
10
+ - [概要](#概要)
11
+ - [ハードウェア要件](#ハードウェア要件)
12
+ - [特徴](#特徴)
13
+ - [インストール](#インストール)
14
+ - [モデルのダウンロード](#モデルのダウンロード)
15
+ - [HunyuanVideoの公式モデルを使う](#HunyuanVideoの公式モデルを使う)
16
+ - [Text EncoderにComfyUI提供のモデルを使う](#Text-EncoderにComfyUI提供のモデルを使う)
17
+ - [使い方](#使い方)
18
+ - [データセット設定](#データセット設定)
19
+ - [latentの事前キャッシュ](#latentの事前キャッシュ)
20
+ - [Text Encoder出力の事前キャッシュ](#Text-Encoder出力の事前キャッシュ)
21
+ - [学習](#学習)
22
+ - [LoRAの重みのマージ](#LoRAの重みのマージ)
23
+ - [推論](#推論)
24
+ - [LoRAの形式の変換](#LoRAの形式の変換)
25
+ - [その他](#その他)
26
+ - [SageAttentionのインストール方法](#SageAttentionのインストール方法)
27
+ - [免責事項](#免責事項)
28
+ - [コントリビューションについて](#コントリビューションについて)
29
+ - [ライセンス](#ライセンス)
30
+
31
+ ## はじめに
32
+
33
+ このリポジトリは、HunyuanVideoのLoRA学習用のコマンドラインツールです。このリポジトリは非公式であり、公式のHunyuanVideoリポジトリとは関係ありません。
34
+
35
+ *リポジトリは開発中です。*
36
+
37
+ ### 最近の更新
38
+
39
+ - 2025/01/20
40
+ - uv によるインストール手順を試験的に追加しました。PR [#51](https://github.com/kohya-ss/musubi-tuner/pull/51) bmaltais 氏に感謝いたします。ただ、設定等は詰められていないため、フィードバックを歓迎します。
41
+ - 高度な設定に、[TensorBoard形式のログの保存と参照](./docs/advanced_config.md#save-and-view-logs-in-tensorboard-format--tensorboard形式のログの保存と参照)を追加しました。
42
+
43
+ - 2025/01/19
44
+ - latentとText Encoder出力の事前キャッシュ時に、データセットに含まれないキャッシュファイルを自動で消去するようにしました。これにより予期しないファイルが残り、学習に使用されてしまう問題が解消されます。
45
+ - `--keep_cache`で今までと同様にキャッシュファイルを残すことができます。
46
+ - Text Encoder出力の事前キャッシュ時に、`--skip_existing`を指定すると正しく動作しない問題を修正しました。
47
+
48
+ - 2025/01/18
49
+ - `hv_generate_video.py`でvideo2videoの推論が可能になりました。詳細は[推論](#推論)を参照してください。
50
+
51
+ - 2025/01/16
52
+ - LoRAの重みをマージするスクリプト、`merge_lora.py`が追加されました。PR [#37](https://github.com/kohya-ss/musubi-tuner/pull/37) kaykyr氏に感謝いたします。詳細は[LoRAの重みのマージ](#LoRAの重みのマージ)を参照してください。
53
+ - サンプルの学習設定を、学習率を2e-4に、`--timestep_sampling`を`shift`に、`--discrete_flow_shift`を7.0に変更しました。より高速な学習が期待されます。詳細は[学習](#学習)を参照してください。
54
+
55
+ - 2025/01/14
56
+ - `hv_generate_video.py`に、LoRAマージ後のDiTモデルを保存する`--save_merged_model`オプションを暫定的に追加しました。詳細は[推論](#推論)を参照してください。
57
+
58
+ - 2025/01/13
59
+ - 学習中のサンプル画像(動画)がぼやける現象に対応するため、サンプル生成時の設定を変更しました。詳細は[こちら](./docs/sampling_during_training.md)をご参照ください。
60
+ - 推論時にdiscrete flow shiftとguidance scaleを正しく設定する必要がありますが、学習時の設定がそのまま使われていたため、この事象が発生していました。デフォルト値を設定したため、改善されると思われます。また`--fs`でdiscrete flow shiftを、`--g`でguidance scaleを指定できます。
61
+
62
+ ### リリースについて
63
+
64
+ Musubi Tunerの解説記事執筆や、関連ツールの開発に取り組んでくださる方々に感謝いたします。このプロジェクトは開発中のため、互換性のない変更や機能追加が起きる可能性があります。想定外の互換性問題を避けるため、参照用として[リリース](https://github.com/kohya-ss/musubi-tuner/releases)をお使いください。
65
+
66
+ 最新のリリースとバージョン履歴は[リリースページ](https://github.com/kohya-ss/musubi-tuner/releases)で確認できます。
67
+
68
+ ## 概要
69
+
70
+ ### ハードウェア要件
71
+
72
+ - VRAM: 静止画での学習は12GB以上推奨、動画での学習は24GB以上推奨。
73
+ - *解像度等の学習設定により異なります。* 12GBでは解像度 960x544 以下とし、`--blocks_to_swap`、`--fp8_llm`等の省メモリオプションを使用してください。
74
+ - メインメモリ: 64GB以上を推奨、32GB+スワップで動作するかもしれませんが、未検証です。
75
+
76
+ ### 特徴
77
+
78
+ - 省メモリに特化
79
+ - Windows対応(Linuxでの動作報告もあります)
80
+ - マルチGPUには対応していません
81
+
82
+ ## インストール
83
+
84
+ ### pipによるインストール
85
+
86
+ Python 3.10以上を使用してください(3.10で動作確認済み)。
87
+
88
+ 適当な仮想環境を作成し、ご利用のCUDAバージョンに合わせたPyTorchとtorchvisionをインストールしてください。
89
+
90
+ PyTorchはバージョン2.5.1以上を使用してください([補足](#PyTorchのバージョンについて))。
91
+
92
+ ```bash
93
+ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
94
+ ```
95
+
96
+ 以下のコマンドを使用して、必要な依存関係をインストールします。
97
+
98
+ ```bash
99
+ pip install -r requirements.txt
100
+ ```
101
+
102
+ オプションとして、FlashAttention、SageAttention(推論にのみ使用、インストール方法は[こちら](#SageAttentionのインストール方法)を参照)を使用できます。
103
+
104
+ また、`ascii-magic`(データセットの確認に使用)、`matplotlib`(timestepsの可視化に使用)、`tensorboard`(学習ログの記録に使用)を必要に応じてインストールしてください。
105
+
106
+ ```bash
107
+ pip install ascii-magic matplotlib tensorboard
108
+ ```
109
+ ### uvによるインストール
110
+
111
+ uvを使用してインストールすることもできますが、uvによるインストールは試験的なものです。フィードバックを歓迎します。
112
+
113
+ #### Linux/MacOS
114
+
115
+ ```sh
116
+ curl -LsSf https://astral.sh/uv/install.sh | sh
117
+ ```
118
+
119
+ 表示される指示に従い、pathを設定してください。
120
+
121
+ #### Windows
122
+
123
+ ```powershell
124
+ powershell -c "irm https://astral.sh/uv/install.ps1 | iex"
125
+ ```
126
+
127
+ 表示される指示に従い、PATHを設定するか、この時点でシステムを再起動してください。
128
+
129
+ ## モデルのダウンロード
130
+
131
+ 以下のいずれかの方法で、モデルをダウンロードしてください。
132
+
133
+ ### HunyuanVideoの公式モデルを使う
134
+
135
+ [公式のREADME](https://github.com/Tencent/HunyuanVideo/blob/main/ckpts/README.md)を参考にダウンロードし、任意のディレクトリに以下のように配置します。
136
+
137
+ ```
138
+ ckpts
139
+ ├──hunyuan-video-t2v-720p
140
+ │ ├──transformers
141
+ │ ├──vae
142
+ ├──text_encoder
143
+ ├──text_encoder_2
144
+ ├──...
145
+ ```
146
+
147
+ ### Text EncoderにComfyUI提供のモデルを使う
148
+
149
+ こちらの方法の方がより簡単です。DiTとVAEのモデルはHunyuanVideoのものを使用します。
150
+
151
+ https://huggingface.co/tencent/HunyuanVideo/tree/main/hunyuan-video-t2v-720p/transformers から、[mp_rank_00_model_states.pt](https://huggingface.co/tencent/HunyuanVideo/resolve/main/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt) をダウンロードし、任意のディレクトリに配置します。
152
+
153
+ (同じページにfp8のモデルもありますが、未検証です。)
154
+
155
+ `--fp8_base`を指定して学習する場合は、`mp_rank_00_model_states.pt`の代わりに、[こちら](https://huggingface.co/kohya-ss/HunyuanVideo-fp8_e4m3fn-unofficial)の`mp_rank_00_model_states_fp8.safetensors`を使用可能です。(このファイルは非公式のもので、重みを単純にfloat8_e4m3fnに変換したものです。)
156
+
157
+ また、https://huggingface.co/tencent/HunyuanVideo/tree/main/hunyuan-video-t2v-720p/vae から、[pytorch_model.pt](https://huggingface.co/tencent/HunyuanVideo/resolve/main/hunyuan-video-t2v-720p/vae/pytorch_model.pt) をダウンロードし、任意のディレクトリに配置します。
158
+
159
+ Text EncoderにはComfyUI提供のモデルを使用させていただきます。[ComfyUIのページ](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)を参考に、https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged/tree/main/split_files/text_encoders から、llava_llama3_fp16.safetensors (Text Encoder 1、LLM)と、clip_l.safetensors (Text Encoder 2、CLIP)をダウンロードし、任意のディレクトリに配置します。
160
+
161
+ (同じページにfp8のLLMモデルもありますが、動作未検証です。)
162
+
163
+ ## 使い方
164
+
165
+ ### データセット設定
166
+
167
+ [こちら](./dataset/dataset_config.md)を参照してください。
168
+
169
+ ### latentの事前キャッシュ
170
+
171
+ latentの事前キャッシュは必須です。以下のコマンドを使用して、事前キャッシュを作成してください。(pipによるインストールの場合)
172
+
173
+ ```bash
174
+ python cache_latents.py --dataset_config path/to/toml --vae path/to/ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt --vae_chunk_size 32 --vae_tiling
175
+ ```
176
+
177
+ uvでインストールした場合は、`uv run python cache_latents.py ...`のように、`uv run`を先頭につけてください。以下のコマンドも同様です。
178
+
179
+ その他のオプションは`python cache_latents.py --help`で確認できます。
180
+
181
+ VRAMが足りない場合は、`--vae_spatial_tile_sample_min_size`を128程度に減らし、`--batch_size`を小さくしてください。
182
+
183
+ `--debug_mode image` を指定するとデータセットの画像とキャプションが新規ウィンドウに表示されます。`--debug_mode console`でコンソールに表示されます(`ascii-magic`が必要)。
184
+
185
+ デフォルトではデータセットに含まれないキャッシュファイルは自動的に削除されます。`--keep_cache`を指定すると、キャッシュファイルを残すことができます。
186
+
187
+ ### Text Encoder出力の事前キャッシュ
188
+
189
+ Text Encoder出力の事前キャッシュは必須です。以下のコマンドを使用して、事前キャッシュを作成してください。
190
+
191
+ ```bash
192
+ python cache_text_encoder_outputs.py --dataset_config path/to/toml --text_encoder1 path/to/ckpts/text_encoder --text_encoder2 path/to/ckpts/text_encoder_2 --batch_size 16
193
+ ```
194
+
195
+ その他のオプションは`python cache_text_encoder_outputs.py --help`で確認できます。
196
+
197
+ `--batch_size`はVRAMに合わせて調整してください。
198
+
199
+ VRAMが足りない場合(16GB程度未満の場合)は、`--fp8_llm`を指定して、fp8でLLMを実行してください。
200
+
201
+ デフォルトではデータセットに含まれないキャッシュファイルは自動的に削除されます。`--keep_cache`を指定すると、キャッシュファイルを残すことができます。
202
+
203
+ ### 学習
204
+
205
+ 以下のコマンドを使用して、学習を開始します(実際には一行で入力してください)。
206
+
207
+ ```bash
208
+ accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 hv_train_network.py
209
+ --dit path/to/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt
210
+ --dataset_config path/to/toml --sdpa --mixed_precision bf16 --fp8_base
211
+ --optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing
212
+ --max_data_loader_n_workers 2 --persistent_data_loader_workers
213
+ --network_module networks.lora --network_dim 32
214
+ --timestep_sampling shift --discrete_flow_shift 7.0
215
+ --max_train_epochs 16 --save_every_n_epochs 1 --seed 42
216
+ --output_dir path/to/output_dir --output_name name-of-lora
217
+ ```
218
+
219
+ __更新__:サンプルの学習率を1e-3から2e-4に、`--timestep_sampling`を`sigmoid`から`shift`に、`--discrete_flow_shift`を1.0から7.0に変更しました。より高速な学習が期待されます。ディテールが甘くなる場合は、discrete flow shiftを3.0程度に下げてみてください。
220
+
221
+ ただ、適切な学習率、学習ステップ数、timestepsの分布、loss weightingなどのパラメータは、依然として不明な点が数多くあります。情報提供をお待ちしています。
222
+
223
+ その他のオプションは`python hv_train_network.py --help`で確認できます(ただし多くのオプションは動作未確認です)。
224
+
225
+ `--fp8_base`を指定すると、DiTがfp8で学習されます。未指定時はmixed precisionのデータ型が使用されます。fp8は大きく消費メモリを削減できますが、品質は低下する可能性があります。`--fp8_base`を指定しない場合はVRAM 24GB以上を推奨します。また必要に応じて`--blocks_to_swap`を使用してください。
226
+
227
+ VRAMが足りない場合は、`--blocks_to_swap`を指定して、一部のブロックをCPUにオフロードしてください。最大36が指定できます。
228
+
229
+ (block swapのアイデアは2kpr氏の実装に基づくものです。2kpr氏にあらためて感謝します。)
230
+
231
+ `--sdpa`でPyTorchのscaled dot product attentionを使用します。`--flash_attn`で[FlashAttention](https://github.com/Dao-AILab/flash-attention)を使用します。`--xformers`でxformersの利用も可能ですが、xformersを使う場合は`--split_attn`を指定してください。`--sage_attn`でSageAttentionを使用しますが、SageAttentionは現時点では学習に未対応のため、正しく動作しません。
232
+
233
+ `--split_attn`を指定すると、attentionを分割して処理します。速度が多少低下しますが、VRAM使用量はわずかに減ります。
234
+
235
+ 学習されるLoRAの形式は、`sd-scripts`と同じです。
236
+
237
+ `--show_timesteps`に`image`(`matplotlib`が必要)または`console`を指定すると、学習時のtimestepsの分布とtimestepsごとのloss weightingが確認できます。
238
+
239
+ 学習時のログの記録が可能です。[TensorBoard形式のログの保存と参照](./docs/advanced_config.md#save-and-view-logs-in-tensorboard-format--tensorboard形式のログの保存と参照)を参照してください。
240
+
241
+ 学習中のサンプル画像生成については、[こちらのドキュメント](./docs/sampling_during_training.md)を参照してください。その他の高度な設定については[こちらのドキュメント](./docs/advanced_config.md)を参照してください。
242
+
243
+ ### LoRAの重みのマージ
244
+
245
+ ```bash
246
+ python merge_lora.py \
247
+ --dit path/to/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \
248
+ --lora_weight path/to/lora.safetensors \
249
+ --save_merged_model path/to/merged_model.safetensors \
250
+ --device cpu \
251
+ --lora_multiplier 1.0
252
+ ```
253
+
254
+ `--device`には計算を行うデバイス(`cpu`または`cuda`等)を指定してください。`cuda`を指定すると計算が高速化されます。
255
+
256
+ `--lora_weight`にはマージするLoRAの重みを、`--lora_multiplier`にはLoRAの重みの係数を、それぞれ指定してください。複数個が指定可能で、両者の数は一致させてください。
257
+
258
+ ### 推論
259
+
260
+ 以下のコマンドを使用して動画を生成します。
261
+
262
+ ```bash
263
+ python hv_generate_video.py --fp8 --video_size 544 960 --video_length 5 --infer_steps 30
264
+ --prompt "A cat walks on the grass, realistic style." --save_path path/to/save/dir --output_type both
265
+ --dit path/to/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt --attn_mode sdpa --split_attn
266
+ --vae path/to/ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt
267
+ --vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128
268
+ --text_encoder1 path/to/ckpts/text_encoder
269
+ --text_encoder2 path/to/ckpts/text_encoder_2
270
+ --seed 1234 --lora_multiplier 1.0 --lora_weight path/to/lora.safetensors
271
+ ```
272
+
273
+ その他のオプションは`python hv_generate_video.py --help`で確認できます。
274
+
275
+ `--fp8`を指定すると、DiTがfp8で推論されます。fp8は大きく消費メモリを削減できますが、品質は低下する可能性があります。
276
+
277
+ VRAMが足りない場合は、`--blocks_to_swap`を指定して、一部のブロックをCPUにオフロードしてください。最大38が指定できます。
278
+
279
+ `--attn_mode`には`flash`、`torch`、`sageattn`、`xformers`または`sdpa`(`torch`指定時と同じ)のいずれかを指定してください。それぞれFlashAttention、scaled dot product attention、SageAttention、xformersに対応します。デフォルトは`torch`です。SageAttentionはVRAMの削減に有効です。
280
+
281
+ `--split_attn`を指定すると、attentionを分割して処理します。SageAttention利用時で10%程度の高速化が見込まれます。
282
+
283
+ `--output_type`には`both`、`latent`、`video`、`images`のいずれかを指定してください。`both`はlatentと動画の両方を出力します。VAEでOut of Memoryエラーが発生する場合に備えて、`both`を指定することをお勧めします。`--latent_path`に保存されたlatentを指定し、`--output_type video` (または`images`)としてスクリプトを実行すると、VAEのdecodeのみを行えます。
284
+
285
+ `--seed`は省略可能です。指定しない場合はランダムなシードが使用されます。
286
+
287
+ `--video_length`は「4の倍数+1」を指定してください。
288
+
289
+ `--flow_shift`にタイムステップのシフト値(discrete flow shift)を指定可能です。省略時のデフォルト値は7.0で、これは推論ステップ数が50の時の推奨値です。HunyuanVideoの論文では、ステップ数50の場合は7.0、ステップ数20未満(10など)で17.0が推奨されています。
290
+
291
+ `--video_path`に読み込む動画を指定すると、video2videoの推論が可能です。動画ファイルを指定するか、複数の画像ファイルが入ったディレクトリを指定してください(画像ファイルはファイル名でソートされ、各フレームとして用いられます)。`--video_length`よりも短い動画を指定するとエラーになります。`--strength`で強度を指定できます。0~1.0で指定でき、大きいほど元の動画からの変化が大きくなります。
292
+
293
+ なおvideo2video推論の処理は実験的なものです。
294
+
295
+ `--save_merged_model`オプションで、LoRAマージ後のDiTモデルを保存できます。`--save_merged_model path/to/merged_model.safetensors`のように指定してください。なおこのオプションを指定すると推論は行われません。
296
+
297
+ ### LoRAの形式の変換
298
+
299
+ ComfyUIで使用可能な形式(Diffusion-pipeと思われる)への変換は以下のコマンドで行えます。
300
+
301
+ ```bash
302
+ python convert_lora.py --input path/to/musubi_lora.safetensors --output path/to/another_format.safetensors --target other
303
+ ```
304
+
305
+ `--input`と`--output`はそれぞれ入力と出力のファイルパスを指定してください。
306
+
307
+ `--target`には`other`を指定してください。`default`を指定すると、他の形式から当リポジトリの形式に変換できます。
308
+
309
+ ## その他
310
+
311
+ ### SageAttentionのインストール方法
312
+
313
+ sdbds氏によるWindows対応のSageAttentionのwheelが https://github.com/sdbds/SageAttention-for-windows で公開されています。triton をインストールし、Python、PyTorch、CUDAのバージョンが一致する場合は、[Releases](https://github.com/sdbds/SageAttention-for-windows/releases)からビルド済みwheelをダウンロードしてインストールすることが可能です。sdbds氏に感謝します。
314
+
315
+ 参考までに、以下は、SageAttentionをビルドしインストールするための簡単な手順です。Microsoft Visual C++ 再頒布可能パッケージを最新にする必要があるかもしれません。
316
+
317
+ 1. Pythonのバージョンに応じたtriton 3.1.0のwheelを[こちら](https://github.com/woct0rdho/triton-windows/releases/tag/v3.1.0-windows.post5)からダウンロードしてインストールします。
318
+
319
+ 2. Microsoft Visual Studio 2022かBuild Tools for Visual Studio 2022を、C++のビルドができるよう設定し、インストールします。(上のRedditの投稿を参照してください)。
320
+
321
+ 3. 任意のフォルダにSageAttentionのリポジトリをクローンします。
322
+ ```shell
323
+ git clone https://github.com/thu-ml/SageAttention.git
324
+ ```
325
+
326
+ なお `git clone https://github.com/sdbds/SageAttention-for-windows.git` で、前述のsdbds氏のリポジトリを使用することで、手順4.を省略できます。
327
+
328
+ 4. `SageAttention/csrc`フォルダ内の`math.cuh`を開き、71行目と146行目の `ushort` を `unsigned short` に変更して保存します。
329
+
330
+ 5. スタートメニューから Visual Studio 2022 内の `x64 Native Tools Command Prompt for VS 2022` を選択してコマンドプロンプトを開きます。
331
+
332
+ 6. venvを有効にし、SageAttentionのフォルダに移動して以下のコマンドを実行します。DISTUTILSが設定されていない、のようなエラーが出た場合は `set DISTUTILS_USE_SDK=1`としてから再度実行してください。
333
+ ```shell
334
+ python setup.py install
335
+ ```
336
+
337
+ 以上でSageAttentionのインストールが完了です。
338
+
339
+ ### PyTorchのバージョンについて
340
+
341
+ `--attn_mode`に`torch`を指定する場合、2.5.1以降のPyTorchを使用してください(それより前のバージョンでは生成される動画が真っ黒になるようです)。
342
+
343
+ 古いバージョンを使う場合、xformersやSageAttentionを使用してください。
344
+
345
+ ## 免責事項
346
+
347
+ このリポジトリは非公式であり、公式のHunyuanVideoリポジトリとは関係ありません。また、このリポジトリは開発中で、実験的なものです。テストおよびフィードバックを歓迎しますが、以下の点にご注意ください:
348
+
349
+ - 実際の稼働環境での動作を意図したものではありません
350
+ - 機能やAPIは予告なく変更されることがあります
351
+ - いくつもの機能が未検証です
352
+ - 動画学習機能はまだ開発中です
353
+
354
+ 問題やバグについては、以下の情報とともにIssueを作成してください:
355
+
356
+ - 問題の詳細な説明
357
+ - 再現手順
358
+ - 環境の詳細(OS、GPU、VRAM、Pythonバージョンなど)
359
+ - 関連するエラーメッセージやログ
360
+
361
+ ## コントリビューションについて
362
+
363
+ コントリビューションを歓迎します。ただし、以下にご注意ください:
364
+
365
+ - メンテナーのリソースが限られているため、PRのレビューやマージには時間がかかる場合があります
366
+ - 大きな変更に取り組む前には、議論のためのIssueを作成してください
367
+ - PRに関して:
368
+ - 変更は焦点を絞り、適度なサイズにしてください
369
+ - 明確な説明をお願いします
370
+ - 既存のコードスタイルに従ってください
371
+ - ドキュメントが更新されていることを確認してください
372
+
373
+ ## ライセンス
374
+
375
+ `hunyuan_model`ディレクトリ以下のコードは、[HunyuanVideo](https://github.com/Tencent/HunyuanVideo)のコードを一部改変して使用しているため、そちらのライセンスに従います。
376
+
377
+ 他のコードはApache License 2.0に従います。一部Diffusersのコードをコピー、改変して使用しています。
README.md CHANGED
@@ -1,12 +1,269 @@
1
- ---
2
- title: Framepack H111
3
- emoji: 🌖
4
- colorFrom: pink
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 5.31.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ ![GUI Screenshot](images/screenshot.png)
2
+
3
+ # Recent update
4
+ 5/25/2025
5
+ Enable full intermediate previews for the framepack tab; some changes to the framepack extension's image input logic.
6
+ 5/24/2025
7
+ Batch images from folder now available for framepack. Save only extension option and experimental start image for f1 in framepack extension tab.
8
+ 5/23/2025
9
+ Added ability to use the normal framepack model with endframe support in the framepack extension tab. Support additional bucket sizes.
10
+ 5/18/2025
11
+ Add video extension tab. Currently only works with f1 model. Full credit to @pfqt and @chaojie for their amazing work!
12
+
13
+ # H1111
14
+
15
+ This is a GUI for tech wizard kohya-ss's musubi tuner's inference script.
16
+ https://github.com/kohya-ss/musubi-tuner
17
+
18
+ It allows inference with these models:
19
+ FramePack
20
+ Hunyuan-t2v
21
+ Hunyuan-i2v
22
+ Hunyuan-v2v
23
+ WanX-t2v
24
+ WanX-i2v
25
+ WanX-v2v
26
+ SkyReels-i2v
27
+ SkyReels-t2v
28
+
29
+ I have mostly been working on the FramePack tab and the WanX-i2v tab; they are the best to use right now. WanX-i2v is also used for the SkyReels V2 and Fun Control models.
30
+
31
+ This supports queuing multiple different jobs if you open 2+ browser tabs and use the same model.
32
+
33
+ If you are running out of VRAM, use more block swapping. FP8 scaled is also a decent option for lowering memory usage: select both fp8 and fp8 scaled to enable it. Scaled fp8 tries to preserve the important parts of the FP16 model. Sage attention is the fastest and uses the least VRAM, but it is difficult to install on Windows.
34
+
35
+ Best quality is obtained by enabling only block swapping and using the fp16 model with sdpa attention. You can speed things up with cfg skip, fp8 scaled, or slg skip (a small speedup); sage attention is the fastest, but all of these speedups come with some quality degradation. This tool is designed to focus on quality over speed.
36
+
37
+ If you are using a LoRA that you did not train with musubi, you need to drag it to the convert lora tab and convert it to the default format. The converted file is written to the /lora folder (see the command-line sketch below).
38
+
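+ For reference, here is a minimal command-line sketch of the same conversion, assuming the convert lora tab wraps musubi tuner's `convert_lora.py` script (the file paths below are placeholders):
+
+ ```bash
+ # Convert a LoRA trained outside musubi tuner into the default format used by this repo.
+ # --target default converts from another format into the musubi/default format.
+ python convert_lora.py \
+   --input path/to/external_lora.safetensors \
+   --output lora/converted_lora.safetensors \
+   --target default
+ ```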
39
+ If you need additional installation instructions or information, create an issue and I will try to help. There are also a lot of settings notes on the musubi GitHub linked above.
40
+
41
+ For torch 2.7.0 on Windows, try:
42
+ pip install typing-extensions
43
+ pip install torch==2.7.0+cu128 torchvision==0.22.0+cu128 --index-url https://download.pytorch.org/whl/cu128
44
+ pip install -r requirementsTorch27.txt
45
+
46
+ ## To Use FramePack
47
+
48
+
49
+ Download these 5 files from https://huggingface.co/maybleMyers/framepack_h1111 and put them in a subfolder named hunyuan (H1111/hunyuan), or point the GUI to wherever they are if you have already acquired them.
50
+
51
+ FramePackI2V_HY_bf16.safetensors or FramePack_F1_I2V_HY_20250503.safetensors for F1
52
+
53
+ clip_l.safetensors
54
+
55
+ llava_llama3_fp16.safetensors
56
+
57
+ model.safetensors
58
+
59
+ pytorch_model.pt
60
+
61
+ LoRAs trained with musubi tuner's framepack training are confirmed to work great; normal LoRAs trained for hunyuan are rather poor. Use a lot of block swap, as this is a different back end than the official repo. If you select fp8 and fp8 scaled it will all fit on a 24gb gpu for the fastest speed, about 3s/it or 1:17 per second of video with a 4090. Best quality will still be obtained with just block swapping, sdpa attention, and the full model.
62
+
63
+ Put LoRAs in a /lora subfolder; if they were not trained with musubi, you need to convert them first.
64
+
65
+ Only unipc is supported for now, and sage attn is experimental. When using the F1 model, not all options available for the original framepack model will work, such as endframe and sectional images.
66
+
67
+ Here is an example prompt for a 5 second video with 4 sections using sectional prompting; longer videos are also supported with index ranges, e.g. 0-2 ;;;3-5 etc.:
68
+
69
+ 0:A cinematic video showcases a cute blue penguin wearing sunglasses. The penguin runs quickly into mcdonalds.;;;1:The penguin runs quickly into mcdonalds and jumps up on a table and starts eating his food. The penguin's name is Piplup he is a famous Pokemon actor. The video is a fast action sequence animation showing the penguin running into a mcdonalds an jumping up onto a table.;;;2:The penguin is seated at a table and is enjoying his happy meal. The penguin's name is Piplup he is a famous Pokemon actor. The video is a fast action sequence animation showing the penguin running into a mcdonalds and jumping up onto a table.;;;3:The penguin is seated at a table and is happily enjoying his happy meal. The penguin's name is Piplup he is a famous Pokemon actor. The penguin flexes his huge arm muscles at the end of the video.
70
+
71
+ I have added support for 4 sectional images during inference. It works best when the images are close together. Refer to the screenshot for an example of a working 5 second video.
72
+
73
+ For more details on using framepack with musubi go here https://github.com/kohya-ss/musubi-tuner/blob/main/docs/framepack.md
74
+
75
+ Fastest speed will be achieved with fp8 and fp8 scaled, then you can reduce block swapping to your memory constraints. (leave about 1gb free)
76
+
77
+ Framepack Extension tab is still a work in progress.
78
+ Thanks to @pftq https://github.com/pftq and @chaojie https://github.com/chaojie for their work on the extension logics.
79
+
80
+ ## To Use the new Skyreels-V2 models
81
+
82
+ I have provided these 2 at https://huggingface.co/maybleMyers/wan_files_for_h1111
83
+
84
+ SkyReels-V2-I2V-14B-720P-FP16.safetensors
85
+ SkyReels-V2-I2V-14B-540P-FP16.safetensors
86
+
87
+ You can just drop them into the wan folder and use them in the WanX-i2v tab. Skyreels-V2 is a fine tune from Wan2.1.
88
+ If you have downloaded the Kijai variants, they will not work because he added extra keys to the model.
89
+
90
+ ## To Use WanX
91
+
92
+ To use WanX, download these and put them in the wan subfolder (a download sketch follows the list of links below):
93
+ Download the T5 `models_t5_umt5-xxl-enc-bf16.pth`, vae `Wan2.1_VAE.pth` and CLIP `models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` from the following page: https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P/tree/main
94
+
95
+ Download the DiT weights from the following page: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
96
+ ie : wan2.1_i2v_720p_14B_fp16.safetensors
97
+
98
+ For the fun control option in WanX-i2v I recommend the fp16 weights here: https://huggingface.co/maybleMyers/wan_files_for_h1111/tree/main
99
+ Wan2.1-Fun-14B-Control_fp16.safetensors
100
+
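+ For reference, a minimal sketch of fetching these files with the `huggingface-cli download` command (repo and file names are taken from the pages above; the exact local layout is an assumption, so move the DiT file directly into the wan folder if it lands under a repo subfolder):
+
+ ```bash
+ # T5, VAE and CLIP weights from the Wan-AI repo
+ huggingface-cli download Wan-AI/Wan2.1-I2V-14B-720P \
+   models_t5_umt5-xxl-enc-bf16.pth \
+   Wan2.1_VAE.pth \
+   models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
+   --local-dir wan
+
+ # DiT weights from the Comfy-Org repackaged repo
+ huggingface-cli download Comfy-Org/Wan_2.1_ComfyUI_repackaged \
+   split_files/diffusion_models/wan2.1_i2v_720p_14B_fp16.safetensors \
+   --local-dir wan
+ ```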
101
+ git pull to update the installation
102
+ pip install -r requirements.txt
103
+
104
+ I have tested the 14B i2v and t2v models so far and they are working.
105
+
106
+ ## Requirements
107
+
108
+ - Python 3.10
109
+ - CUDA 12.4
110
+
111
+ ## Basic Installation (Linux)
112
+
113
+ Tested on ubuntu 24
114
+
115
+ To update, navigate to H1111 and run git pull.
116
+
117
+ ```bash
118
+ git clone https://github.com/maybleMyers/H1111
119
+ cd H1111
120
+ python -m venv env
121
+ #(if you have another version of python do python3.10 -m venv env after you install it with sudo apt install python3.10 python3.10-venv python3.10-distutils)
122
+ source env/bin/activate
123
+ pip install torch==2.5.1 torchvision --index-url https://download.pytorch.org/whl/cu124
124
+ pip install -r requirements.txt
125
+ pip install flash-attn --no-build-isolation
126
+ pip install sageattention==1.0.6
127
+ # might need python3.10-dev as well for sage attention to work
128
+
129
+ ```
130
+
131
+ run with
132
+ source env/bin/activate
133
+ python h1111.py
134
+
135
+ for GPU1
136
+ CUDA_VISIBLE_DEVICES=1 python h1111.py
137
+
138
+ ## Basic Installation (Windows)
139
+
140
+
141
+
142
+ First, open PowerShell and navigate to your desired installation directory. Then run these commands:
143
+
144
+ ```powershell
145
+ git clone https://github.com/maybleMyers/H1111
146
+ cd H1111
147
+ python -m venv env
148
+ ./env/scripts/activate
149
+ pip install torch==2.5.1 torchvision --index-url https://download.pytorch.org/whl/cu124
150
+ pip install -r requirements.txt
151
+
152
+ ```
153
+
154
+ ## To run
155
+
156
+ ```
157
+ env/scripts/activate
158
+ python h1111.py
159
+ ```
160
+
161
+ open 127.0.0.1:7860 in a browser
162
+
163
+ You can set the CUDA device to 1,2,3,4,5,6,7 etc. in the env (activated in a separate terminal) to run additional copies at once if you have more GPUs.
164
+ ie for linux on the second gpu: CUDA_VISIBLE_DEVICES=1 python h1111.py
165
+
166
+ ## Full changelog
167
+ 5/25/2025
168
+ Enable full intermediate previews for the framepack tab; some changes to the framepack extension's image input logic.
169
+ 5/24/2025
170
+ Batch images from folder now available for framepack. Save only extension option and experimental start image for f1 in framepack extension tab.
171
+ 5/23/2025
172
+ Added ability to use the normal framepack model with endframe support in the framepack extension tab. Support additional bucket sizes.
173
+ 5/18/2025
174
+ Add video extension tab. Currently only works with f1 model. Full credit to @pfqt and @chaojie for their amazing work!
175
+ 5/12/2025
176
+ Add skip button to framepack.
177
+ 5/9/2025
178
+ Add testing branch for framepack F1 end image, kinda glitchy: https://github.com/maybleMyers/H1111/tree/f1_end
179
+ 5/5/2025
180
+ Update an experimental hunyuan to framepack convert lora option in the convert lora tab.
181
+ Add tea cache to frame pack.
182
+ 5/3/2025
183
+ Add support for framepack F1! download from https://huggingface.co/maybleMyers/wan_files_for_h1111/blob/main/FramePack_F1_I2V_HY_20250503.safetensors put it in your hunyuan folder. You might need to reinstall reqs "pip install -r requirements.txt"
184
+ Add support for Wan2.1 i2v-14B-FC-1.1. It is a fun control model and is very good. Use it in the WanX-i2v tab and make sure to select the task i2v-14B-FC-1.1 at the bottom of the page. Download the weights from https://huggingface.co/maybleMyers/wan_files_for_h1111
185
+ 4/30/2025
186
+ Previews for framepack.
187
+ 4/29/2025
188
+ Add initial preview support to the wanX-i2v tab. If you want to use previews, use the preview branch. Thanks to Sarania.
189
+ Wan2.1-Fun-V1.1-14B-InP-FP16.safetensors is available at https://huggingface.co/maybleMyers/wan_files_for_h1111
190
+ Fix bug in hunyuan-t2v not loading lora.
191
+ 4/26/2025
192
+ Add SkyReels-V2-I2V-14B-720P-FP16.safetensors to supported models.
193
+ Added much better options for Framepack, including working sectional images. Thanks to kohya!
194
+ 4/25/2025
195
+ Framepack backend updates for better LoRA support for LoRAs trained with musubi tuner. Also better weighting options.
196
+ 4/24/2025
197
+ Update FramePack backend to musubi backend instead of original. Offers much improved speed and some quality improvements.
198
+ Add support for torch 2.7.0 + cuda 12.8
199
+ 4/18/2025
200
+ Add initial support for FramePack. https://github.com/lllyasviel/FramePack
201
+ 4/15/2025
202
+ Add much improved functionality for the wan fun control model. Added strength improvements and dropoff code to choose when to apply the control video. Thanks wordbrew.
203
+ 4/3/2025
204
+ Add support for hunyuan i2v model. Download the clip vision from https://huggingface.co/maybleMyers/H1111_Hunyuan_i2v And download the official model from hunyuan's website and rename it to mp_rank_00_model_states_i2v.pt https://huggingface.co/tencent/HunyuanVideo-I2V/tree/main/hunyuan-video-i2v-720p/transformers add both to your hunyuan folder.
205
+ 3/29/2025
206
+ Added support for fun models! download dit from https://huggingface.co/alibaba-pai/Wan2.1-Fun-14B-Control and specify correct task type and dit location. I renamed it from diffusion_pytorch_model to Wan2.1-Fun-14B-control. Works in the normal WanX-i2v tab when you select the control option at the bottom of the page.
207
+ 3/23/2025
208
+ Added WanX cfg skip functionality to skip cfg guidance during inference for faster generation at the cost of weaker prompt adherence.
209
+ 3/22/2025
210
+ Added WanX-i2v end frame functionality
211
+ 3/20/2025
212
+ Added WanX-v2v functionality.
213
+ 3/18/2025
214
+ Added Skip Layer Guidance for WanX-i2v.
215
+ 3/13/2025
216
+ Added extend video functionality to WanX-i2v. It kind of works.
217
+ 3/12/2025
218
+ Added ability to send the last frame of a video to the input in WanX-i2v. Also you can now use this to extend the video. You can do multiple batches at each step and pick the best extended video then generate an even longer one.
219
+ 3/9/2025
220
+ Added batching ability for a folder full of images in WanX-i2v tab. Added flash attn for windows prebuilt wheel.
221
+ 3/8/2025
222
+ Added support for wan lora's. Remember to convert them first in the convert lora tab.
223
+ 3/5/2025
224
+ Added ability to batch a folder of images with skyreels i2v, so you can make a video with every image in a folder.
225
+ 3/2/2025
226
+ Added initial support for wanX-2.1 Image to Video and Text to Video inference.
227
+ 3/1/2025
228
+ Added support for Skyreels Video to Video and Text to Video.
229
+ 2/23/2025
230
+ Added initial support for skyreels-V1 using musubi's skyreel implementation. (thanks sdbds)
231
+ download models from https://huggingface.co/Kijai/SkyReels-V1-Hunyuan_comfy and add them to your hunyuan folder
232
+ skyreels_hunyuan_i2v_bf16.safetensors
233
+ skyreels_hunyuan_t2v_bf16.safetensors
234
+
235
+
236
+ ## To Use stock hunyuan models
237
+
238
+ https://huggingface.co/tencent/HunyuanVideo/resolve/main/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt
239
+
240
+ https://huggingface.co/tencent/HunyuanVideo/resolve/main/hunyuan-video-t2v-720p/vae/pytorch_model.pt
241
+
242
+ https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged/resolve/main/split_files/text_encoders/llava_llama3_fp16.safetensors
243
+
244
+ https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged/resolve/main/split_files/text_encoders/clip_l.safetensors
245
+
246
+ ### fp8 dit model
247
+
248
+ https://huggingface.co/kohya-ss/HunyuanVideo-fp8_e4m3fn-unofficial/resolve/main/mp_rank_00_model_states_fp8.safetensors
249
+
250
+ Place the models in the H1111/hunyuan folder (see the layout sketch below).
251
+
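+ For reference, a rough sketch of the expected layout after downloading (file names are taken from the links above; treating hunyuan as a plain subfolder is an assumption based on the instructions in this README):
+
+ ```
+ H1111/
+ └── hunyuan/
+     ├── mp_rank_00_model_states.pt               # DiT
+     ├── pytorch_model.pt                         # VAE
+     ├── llava_llama3_fp16.safetensors            # Text Encoder 1 (LLM)
+     ├── clip_l.safetensors                       # Text Encoder 2 (CLIP)
+     └── mp_rank_00_model_states_fp8.safetensors  # optional fp8 DiT
+ ```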
252
+ ### Optional: Install Xformers
253
+ ```powershell
254
+ pip install --no-deps xformers --index-url https://download.pytorch.org/whl/cu124
255
+ ```
256
+
257
+ ### Optional: Install Flash Attention
258
+ Note: This can take 1-5 hours to install even on a good CPU, but provides faster generation.
259
+ I have uploaded a wheel for Windows users to match cuda 12.4 and python 3.10 (thanks lldacing):
260
+ https://huggingface.co/maybleMyers/wan_files_for_h1111/resolve/main/flash_attn-2.7.4%2Bcu124torch2.5.1cxx11abiFALSE-cp310-cp310-win_amd64.whl?download=true
261
+
262
+ ```powershell
263
+ pip install flash-attn --no-build-isolation
264
+
265
+ # If you have downloaded the wheel, you can install it with:
266
+
267
+ pip install "flash_attn-2.7.4+cu124torch2.5.1cxx11abiFALSE-cp310-cp310-win_amd64.whl"
268
+ ```
base_hv_generate_video.py ADDED
@@ -0,0 +1,936 @@
1
+ import argparse
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+ import random
5
+ import sys
6
+ import os
7
+ import time
8
+ from typing import Optional, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torchvision
13
+ import accelerate
14
+ from diffusers.utils.torch_utils import randn_tensor
15
+ from transformers.models.llama import LlamaModel
16
+ from tqdm import tqdm
17
+ import av
18
+ from einops import rearrange
19
+ from safetensors.torch import load_file, save_file
20
+ from safetensors import safe_open
21
+ from PIL import Image
22
+
23
+ from hunyuan_model import vae
24
+ from hunyuan_model.text_encoder import TextEncoder
25
+ from hunyuan_model.text_encoder import PROMPT_TEMPLATE
26
+ from hunyuan_model.vae import load_vae
27
+ from hunyuan_model.models import load_transformer, get_rotary_pos_embed
28
+ from hunyuan_model.fp8_optimization import convert_fp8_linear
29
+ from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
30
+ from networks import lora
31
+
32
+ try:
33
+ from lycoris.kohya import create_network_from_weights
34
+ except:
35
+ pass
36
+
37
+ from utils.model_utils import str_to_dtype
38
+ from utils.safetensors_utils import mem_eff_save_file
39
+ from dataset.image_video_dataset import load_video, glob_images, resize_image_to_bucket
40
+
41
+ import logging
42
+
43
+ logger = logging.getLogger(__name__)
44
+ logging.basicConfig(level=logging.INFO)
45
+
46
+
47
+ def clean_memory_on_device(device):
48
+ if device.type == "cuda":
49
+ torch.cuda.empty_cache()
50
+ elif device.type == "cpu":
51
+ pass
52
+ elif device.type == "mps": # not tested
53
+ torch.mps.empty_cache()
54
+
55
+
56
+ def synchronize_device(device: torch.device):
57
+ if device.type == "cuda":
58
+ torch.cuda.synchronize()
59
+ elif device.type == "xpu":
60
+ torch.xpu.synchronize()
61
+ elif device.type == "mps":
62
+ torch.mps.synchronize()
63
+
64
+
65
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24):
66
+ """save videos by video tensor
67
+ copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61
68
+
69
+ Args:
70
+ videos (torch.Tensor): video tensor predicted by the model
71
+ path (str): path to save video
72
+ rescale (bool, optional): rescale the video tensor from [-1, 1] to . Defaults to False.
73
+ n_rows (int, optional): Defaults to 1.
74
+ fps (int, optional): video save fps. Defaults to 8.
75
+ """
76
+ videos = rearrange(videos, "b c t h w -> t b c h w")
77
+ outputs = []
78
+ for x in videos:
79
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
80
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
81
+ if rescale:
82
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
83
+ x = torch.clamp(x, 0, 1)
84
+ x = (x * 255).numpy().astype(np.uint8)
85
+ outputs.append(x)
86
+
87
+ os.makedirs(os.path.dirname(path), exist_ok=True)
88
+
89
+ # # save video with av
90
+ # container = av.open(path, "w")
91
+ # stream = container.add_stream("libx264", rate=fps)
92
+ # for x in outputs:
93
+ # frame = av.VideoFrame.from_ndarray(x, format="rgb24")
94
+ # packet = stream.encode(frame)
95
+ # container.mux(packet)
96
+ # packet = stream.encode(None)
97
+ # container.mux(packet)
98
+ # container.close()
99
+
100
+ height, width, _ = outputs[0].shape
101
+
102
+ # create output container
103
+ container = av.open(path, mode="w")
104
+
105
+ # create video stream
106
+ codec = "libx264"
107
+ pixel_format = "yuv420p"
108
+ stream = container.add_stream(codec, rate=fps)
109
+ stream.width = width
110
+ stream.height = height
111
+ stream.pix_fmt = pixel_format
112
+ stream.bit_rate = 4000000 # 4Mbit/s
113
+
114
+ for frame_array in outputs:
115
+ frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
116
+ packets = stream.encode(frame)
117
+ for packet in packets:
118
+ container.mux(packet)
119
+
120
+ for packet in stream.encode():
121
+ container.mux(packet)
122
+
123
+ container.close()
124
+
125
+
126
+ def save_images_grid(
127
+ videos: torch.Tensor, parent_dir: str, image_name: str, rescale: bool = False, n_rows: int = 1, create_subdir=True
128
+ ):
129
+ videos = rearrange(videos, "b c t h w -> t b c h w")
130
+ outputs = []
131
+ for x in videos:
132
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
133
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
134
+ if rescale:
135
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
136
+ x = torch.clamp(x, 0, 1)
137
+ x = (x * 255).numpy().astype(np.uint8)
138
+ outputs.append(x)
139
+
140
+ if create_subdir:
141
+ output_dir = os.path.join(parent_dir, image_name)
142
+ else:
143
+ output_dir = parent_dir
144
+
145
+ os.makedirs(output_dir, exist_ok=True)
146
+ for i, x in enumerate(outputs):
147
+ image_path = os.path.join(output_dir, f"{image_name}_{i:03d}.png")
148
+ image = Image.fromarray(x)
149
+ image.save(image_path)
150
+
151
+
152
+ # region Encoding prompt
153
+
154
+
155
+ def encode_prompt(prompt: Union[str, list[str]], device: torch.device, num_videos_per_prompt: int, text_encoder: TextEncoder):
156
+ r"""
157
+ Encodes the prompt into text encoder hidden states.
158
+
159
+ Args:
160
+ prompt (`str` or `List[str]`):
161
+ prompt to be encoded
162
+ device: (`torch.device`):
163
+ torch device
164
+ num_videos_per_prompt (`int`):
165
+ number of videos that should be generated per prompt
166
+ text_encoder (TextEncoder):
167
+ text encoder to be used for encoding the prompt
168
+ """
169
+ # LoRA and Textual Inversion are not supported in this script
170
+ # negative prompt and prompt embedding are not supported in this script
171
+ # clip_skip is not supported in this script because it is not used in the original script
172
+ data_type = "video" # video only, image is not supported
173
+
174
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
175
+
176
+ with torch.no_grad():
177
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type, device=device)
178
+ prompt_embeds = prompt_outputs.hidden_state
179
+
180
+ attention_mask = prompt_outputs.attention_mask
181
+ if attention_mask is not None:
182
+ attention_mask = attention_mask.to(device)
183
+ bs_embed, seq_len = attention_mask.shape
184
+ attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
185
+ attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len)
186
+
187
+ prompt_embeds_dtype = text_encoder.dtype
188
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
189
+
190
+ if prompt_embeds.ndim == 2:
191
+ bs_embed, _ = prompt_embeds.shape
192
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
193
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
194
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
195
+ else:
196
+ bs_embed, seq_len, _ = prompt_embeds.shape
197
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
198
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
199
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
200
+
201
+ return prompt_embeds, attention_mask
202
+
203
+
204
+ def encode_input_prompt(prompt: Union[str, list[str]], args, device, fp8_llm=False, accelerator=None):
205
+ # constants
206
+ prompt_template_video = "dit-llm-encode-video"
207
+ prompt_template = "dit-llm-encode"
208
+ text_encoder_dtype = torch.float16
209
+ text_encoder_type = "llm"
210
+ text_len = 256
211
+ hidden_state_skip_layer = 2
212
+ apply_final_norm = False
213
+ reproduce = False
214
+
215
+ text_encoder_2_type = "clipL"
216
+ text_len_2 = 77
217
+
218
+ num_videos = 1
219
+
220
+ # if args.prompt_template_video is not None:
221
+ # crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
222
+ # elif args.prompt_template is not None:
223
+ # crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
224
+ # else:
225
+ # crop_start = 0
226
+ crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0)
227
+ max_length = text_len + crop_start
228
+
229
+ # prompt_template
230
+ prompt_template = PROMPT_TEMPLATE[prompt_template]
231
+
232
+ # prompt_template_video
233
+ prompt_template_video = PROMPT_TEMPLATE[prompt_template_video] # if args.prompt_template_video is not None else None
234
+
235
+ # load text encoders
236
+ logger.info(f"loading text encoder: {args.text_encoder1}")
237
+ text_encoder = TextEncoder(
238
+ text_encoder_type=text_encoder_type,
239
+ max_length=max_length,
240
+ text_encoder_dtype=text_encoder_dtype,
241
+ text_encoder_path=args.text_encoder1,
242
+ tokenizer_type=text_encoder_type,
243
+ prompt_template=prompt_template,
244
+ prompt_template_video=prompt_template_video,
245
+ hidden_state_skip_layer=hidden_state_skip_layer,
246
+ apply_final_norm=apply_final_norm,
247
+ reproduce=reproduce,
248
+ )
249
+ text_encoder.eval()
250
+ if fp8_llm:
251
+ org_dtype = text_encoder.dtype
252
+ logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
253
+ text_encoder.to(device=device, dtype=torch.float8_e4m3fn)
254
+
255
+ # prepare LLM for fp8
256
+ def prepare_fp8(llama_model: LlamaModel, target_dtype):
257
+ def forward_hook(module):
258
+ def forward(hidden_states):
259
+ input_dtype = hidden_states.dtype
260
+ hidden_states = hidden_states.to(torch.float32)
261
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
262
+ hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
263
+ return module.weight.to(input_dtype) * hidden_states.to(input_dtype)
264
+
265
+ return forward
266
+
267
+ for module in llama_model.modules():
268
+ if module.__class__.__name__ in ["Embedding"]:
269
+ # print("set", module.__class__.__name__, "to", target_dtype)
270
+ module.to(target_dtype)
271
+ if module.__class__.__name__ in ["LlamaRMSNorm"]:
272
+ # print("set", module.__class__.__name__, "hooks")
273
+ module.forward = forward_hook(module)
274
+
275
+ prepare_fp8(text_encoder.model, org_dtype)
276
+
277
+ logger.info(f"loading text encoder 2: {args.text_encoder2}")
278
+ text_encoder_2 = TextEncoder(
279
+ text_encoder_type=text_encoder_2_type,
280
+ max_length=text_len_2,
281
+ text_encoder_dtype=text_encoder_dtype,
282
+ text_encoder_path=args.text_encoder2,
283
+ tokenizer_type=text_encoder_2_type,
284
+ reproduce=reproduce,
285
+ )
286
+ text_encoder_2.eval()
287
+
288
+ # encode prompt
289
+ logger.info(f"Encoding prompt with text encoder 1")
290
+ text_encoder.to(device=device)
291
+ if fp8_llm:
292
+ with accelerator.autocast():
293
+ prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)
294
+ else:
295
+ prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)
296
+ text_encoder = None
297
+ clean_memory_on_device(device)
298
+
299
+ logger.info(f"Encoding prompt with text encoder 2")
300
+ text_encoder_2.to(device=device)
301
+ prompt_embeds_2, prompt_mask_2 = encode_prompt(prompt, device, num_videos, text_encoder_2)
302
+
303
+ prompt_embeds = prompt_embeds.to("cpu")
304
+ prompt_mask = prompt_mask.to("cpu")
305
+ prompt_embeds_2 = prompt_embeds_2.to("cpu")
306
+ prompt_mask_2 = prompt_mask_2.to("cpu")
307
+
308
+ text_encoder_2 = None
309
+ clean_memory_on_device(device)
310
+
311
+ return prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2
312
+
313
+
314
+ # endregion
315
+
316
+
317
+ def prepare_vae(args, device):
318
+ vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
319
+ vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
320
+ vae.eval()
321
+ # vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
322
+
323
+ # set chunk_size to CausalConv3d recursively
324
+ chunk_size = args.vae_chunk_size
325
+ if chunk_size is not None:
326
+ vae.set_chunk_size_for_causal_conv_3d(chunk_size)
327
+ logger.info(f"Set chunk_size to {chunk_size} for CausalConv3d")
328
+
329
+ if args.vae_spatial_tile_sample_min_size is not None:
330
+ vae.enable_spatial_tiling(True)
331
+ vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
332
+ vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
333
+ # elif args.vae_tiling:
334
+ else:
335
+ vae.enable_spatial_tiling(True)
336
+
337
+ return vae, vae_dtype
338
+
339
+
340
+ def encode_to_latents(args, video, device):
341
+ vae, vae_dtype = prepare_vae(args, device)
342
+
343
+ video = video.to(device=device, dtype=vae_dtype)
344
+ video = video * 2 - 1 # 0, 1 -> -1, 1
345
+ with torch.no_grad():
346
+ latents = vae.encode(video).latent_dist.sample()
347
+
348
+ if hasattr(vae.config, "shift_factor") and vae.config.shift_factor:
349
+ latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
350
+ else:
351
+ latents = latents * vae.config.scaling_factor
352
+
353
+ return latents
354
+
355
+
356
+ def decode_latents(args, latents, device):
357
+ vae, vae_dtype = prepare_vae(args, device)
358
+
359
+ expand_temporal_dim = False
360
+ if len(latents.shape) == 4:
361
+ latents = latents.unsqueeze(2)
362
+ expand_temporal_dim = True
363
+ elif len(latents.shape) == 5:
364
+ pass
365
+ else:
366
+ raise ValueError(f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.")
367
+
368
+ if hasattr(vae.config, "shift_factor") and vae.config.shift_factor:
369
+ latents = latents / vae.config.scaling_factor + vae.config.shift_factor
370
+ else:
371
+ latents = latents / vae.config.scaling_factor
372
+
373
+ latents = latents.to(device=device, dtype=vae_dtype)
374
+ with torch.no_grad():
375
+ image = vae.decode(latents, return_dict=False)[0]
376
+
377
+ if expand_temporal_dim:
378
+ image = image.squeeze(2)
379
+
380
+ image = (image / 2 + 0.5).clamp(0, 1)
381
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
382
+ image = image.cpu().float()
383
+
384
+ return image
385
+
386
+
387
+ def parse_args():
388
+ parser = argparse.ArgumentParser(description="HunyuanVideo inference script")
389
+
390
+ parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path or directory")
391
+ parser.add_argument(
392
+ "--dit_in_channels",
393
+ type=int,
394
+ default=None,
395
+ help="input channels for DiT, default is None (automatically detect). 32 for SkyReels-I2V, 16 for others",
396
+ )
397
+ parser.add_argument("--vae", type=str, required=True, help="VAE checkpoint path or directory")
398
+ parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
399
+ parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
400
+ parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
401
+
402
+ # LoRA
403
+ parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
404
+ parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
405
+ parser.add_argument(
406
+ "--save_merged_model",
407
+ type=str,
408
+ default=None,
409
+ help="Save merged model to path. If specified, no inference will be performed.",
410
+ )
411
+ parser.add_argument("--exclude_single_blocks", action="store_true", help="Exclude single blocks when loading LoRA weights")
412
+
413
+ # inference
414
+ parser.add_argument("--prompt", type=str, required=True, help="prompt for generation")
415
+ parser.add_argument("--negative_prompt", type=str, default=None, help="negative prompt for generation")
416
+ parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size")
417
+ parser.add_argument("--video_length", type=int, default=129, help="video length")
418
+ parser.add_argument("--fps", type=int, default=24, help="video fps")
419
+ parser.add_argument("--infer_steps", type=int, default=50, help="number of inference steps")
420
+ parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
421
+ parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
422
+ parser.add_argument(
423
+ "--guidance_scale",
424
+ type=float,
425
+ default=1.0,
426
+ help="Guidance scale for classifier free guidance. Default is 1.0 (means no guidance)",
427
+ )
428
+ parser.add_argument("--embedded_cfg_scale", type=float, default=6.0, help="Embedded classifier-free guidance scale.")
429
+ parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference")
430
+ parser.add_argument(
431
+ "--image_path", type=str, default=None, help="path to image for image2video inference, only works for SkyReels-I2V model"
432
+ )
433
+ parser.add_argument(
434
+ "--split_uncond",
435
+ action="store_true",
436
+ help="split unconditional call for classifier free guidance, slower but less memory usage",
437
+ )
438
+ parser.add_argument("--strength", type=float, default=0.8, help="strength for video2video inference")
439
+
440
+ # Flow Matching
441
+ parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers.")
442
+
443
+ parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
444
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
445
+ parser.add_argument(
446
+ "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
447
+ )
448
+ parser.add_argument(
449
+ "--attn_mode", type=str, default="torch", choices=["flash", "torch", "sageattn", "xformers", "sdpa"], help="attention mode"
450
+ )
451
+ parser.add_argument(
452
+ "--split_attn", action="store_true", help="use split attention, default is False. if True, --split_uncond becomes True"
453
+ )
454
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
455
+ parser.add_argument(
456
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
457
+ )
458
+ parser.add_argument("--blocks_to_swap", type=int, default=None, help="number of blocks to swap in the model")
459
+ parser.add_argument("--img_in_txt_in_offloading", action="store_true", help="offload img_in and txt_in to cpu")
460
+ parser.add_argument(
461
+ "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type"
462
+ )
463
+ parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
464
+ parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
465
+ parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
466
+ parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arthimetic(RTX 4XXX+)")
467
+ parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
468
+ parser.add_argument(
469
+ "--compile_args",
470
+ nargs=4,
471
+ metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"),
472
+ default=["inductor", "max-autotune-no-cudagraphs", "False", "False"],
473
+ help="Torch.compile settings",
474
+ )
475
+
476
+ args = parser.parse_args()
477
+
478
+ assert (args.latent_path is None or len(args.latent_path) == 0) or (
479
+ args.output_type == "images" or args.output_type == "video"
480
+ ), "latent_path is only supported for images or video output"
481
+
482
+ # update dit_weight based on model_base if not exists
483
+
484
+ if args.fp8_fast and not args.fp8:
485
+ raise ValueError("--fp8_fast requires --fp8")
486
+
487
+ return args
488
+
489
+
490
+ def check_inputs(args):
491
+ height = args.video_size[0]
492
+ width = args.video_size[1]
493
+ video_length = args.video_length
494
+
495
+ if height % 8 != 0 or width % 8 != 0:
496
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
497
+ return height, width, video_length
498
+
499
+
500
+ def main():
501
+ args = parse_args()
502
+
503
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
504
+ device = torch.device(device)
505
+ dit_dtype = torch.bfloat16
506
+ dit_weight_dtype = torch.float8_e4m3fn if args.fp8 else dit_dtype
507
+ logger.info(f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}")
508
+
509
+ original_base_names = None
510
+ if args.latent_path is not None and len(args.latent_path) > 0:
511
+ original_base_names = []
512
+ latents_list = []
513
+ seeds = []
514
+ for latent_path in args.latent_path:
515
+ original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
516
+ seed = 0
517
+
518
+ if os.path.splitext(latent_path)[1] != ".safetensors":
519
+ latents = torch.load(latent_path, map_location="cpu")
520
+ else:
521
+ latents = load_file(latent_path)["latent"]
522
+ with safe_open(latent_path, framework="pt") as f:
523
+ metadata = f.metadata()
524
+ if metadata is None:
525
+ metadata = {}
526
+ logger.info(f"Loaded metadata: {metadata}")
527
+
528
+ if "seeds" in metadata:
529
+ seed = int(metadata["seeds"])
530
+
531
+ seeds.append(seed)
532
+ latents_list.append(latents)
533
+
534
+ logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")
535
+ latents = torch.stack(latents_list, dim=0)
536
+ else:
537
+ # prepare accelerator
538
+ mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16"
539
+ accelerator = accelerate.Accelerator(mixed_precision=mixed_precision)
540
+
541
+ # load prompt
542
+ prompt = args.prompt # TODO load prompts from file
543
+ assert prompt is not None, "prompt is required"
544
+
545
+ # check inputs: height, width, video_length etc. may be changed for each generation in the future
546
+ height, width, video_length = check_inputs(args)
547
+
548
+ # encode prompt with LLM and Text Encoder
549
+ logger.info(f"Encoding prompt: {prompt}")
550
+
551
+ do_classifier_free_guidance = args.guidance_scale != 1.0
552
+ if do_classifier_free_guidance:
553
+ negative_prompt = args.negative_prompt
554
+ if negative_prompt is None:
555
+ logger.info("Negative prompt is not provided, using empty prompt")
556
+ negative_prompt = ""
557
+ logger.info(f"Encoding negative prompt: {negative_prompt}")
558
+ prompt = [negative_prompt, prompt]
559
+ else:
560
+ if args.negative_prompt is not None:
561
+ logger.warning("Negative prompt is provided but guidance_scale is 1.0, negative prompt will be ignored.")
562
+
563
+ prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2 = encode_input_prompt(
564
+ prompt, args, device, args.fp8_llm, accelerator
565
+ )
566
+
567
+ # encode latents for video2video inference
568
+ video_latents = None
569
+ if args.video_path is not None:
570
+ # v2v inference
571
+ logger.info(f"Video2Video inference: {args.video_path}")
572
+ video = load_video(args.video_path, 0, video_length, bucket_reso=(width, height)) # list of frames
573
+ if len(video) < video_length:
574
+ raise ValueError(f"Video length is less than {video_length}")
575
+ video = np.stack(video, axis=0) # F, H, W, C
576
+ video = torch.from_numpy(video).permute(3, 0, 1, 2).unsqueeze(0).float() # 1, C, F, H, W
577
+ video = video / 255.0
578
+
579
+ logger.info(f"Encoding video to latents")
580
+ video_latents = encode_to_latents(args, video, device)
581
+ video_latents = video_latents.to(device=device, dtype=dit_dtype)
582
+
583
+ clean_memory_on_device(device)
584
+
585
+ # encode latents for image2video inference
586
+ image_latents = None
587
+ if args.image_path is not None:
588
+ # i2v inference
589
+ logger.info(f"Image2Video inference: {args.image_path}")
590
+
591
+ image = Image.open(args.image_path)
592
+ image = resize_image_to_bucket(image, (width, height)) # returns a numpy array
593
+ image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).unsqueeze(2).float() # 1, C, 1, H, W
594
+ image = image / 255.0
595
+
596
+ logger.info(f"Encoding image to latents")
597
+ image_latents = encode_to_latents(args, image, device) # 1, C, 1, H, W
598
+ image_latents = image_latents.to(device=device, dtype=dit_dtype)
599
+
600
+ clean_memory_on_device(device)
601
+
602
+ # load DiT model
603
+ blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
604
+ loading_device = "cpu" # if blocks_to_swap > 0 else device
605
+
606
+ logger.info(f"Loading DiT model from {args.dit}")
607
+ if args.attn_mode == "sdpa":
608
+ args.attn_mode = "torch"
609
+
610
+ # if image_latents is given, the model should be I2V model, so the in_channels should be 32
611
+ dit_in_channels = args.dit_in_channels if args.dit_in_channels is not None else (32 if image_latents is not None else 16)
612
+
613
+ # if we use LoRA, weights should be bf16 instead of fp8, because merging should be done in bf16
614
+ # the model is too large, so we load the model to cpu. in addition, the .pt file is loaded to cpu anyway
615
+ # on-the-fly merging would be a solution for this issue for .safetensors files (not implemented yet)
616
+ transformer = load_transformer(
617
+ args.dit, args.attn_mode, args.split_attn, loading_device, dit_dtype, in_channels=dit_in_channels
618
+ )
619
+ transformer.eval()
620
+
621
+ # load LoRA weights
622
+ if args.lora_weight is not None and len(args.lora_weight) > 0:
623
+ for i, lora_weight in enumerate(args.lora_weight):
624
+ if args.lora_multiplier is not None and len(args.lora_multiplier) > i:
625
+ lora_multiplier = args.lora_multiplier[i]
626
+ else:
627
+ lora_multiplier = 1.0
628
+
629
+ logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}")
630
+ weights_sd = load_file(lora_weight)
631
+
632
+ # Filter to exclude keys that are part of single_blocks
633
+ if args.exclude_single_blocks:
634
+ filtered_weights = {k: v for k, v in weights_sd.items() if "single_blocks" not in k}
635
+ weights_sd = filtered_weights
636
+
637
+ if args.lycoris:
638
+ lycoris_net, _ = create_network_from_weights(
639
+ multiplier=lora_multiplier,
640
+ file=None,
641
+ weights_sd=weights_sd,
642
+ unet=transformer,
643
+ text_encoder=None,
644
+ vae=None,
645
+ for_inference=True,
646
+ )
647
+ else:
648
+ network = lora.create_arch_network_from_weights(
649
+ lora_multiplier, weights_sd, unet=transformer, for_inference=True
650
+ )
651
+ logger.info("Merging LoRA weights to DiT model")
652
+
653
+ # try:
654
+ # network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True)
655
+ # info = network.load_state_dict(weights_sd, strict=True)
656
+ # logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
657
+ # network.eval()
658
+ # network.to(device)
659
+ # except Exception as e:
660
+ if args.lycoris:
661
+ lycoris_net.merge_to(None, transformer, weights_sd, dtype=None, device=device)
662
+ else:
663
+ network.merge_to(None, transformer, weights_sd, device=device, non_blocking=True)
664
+
665
+ synchronize_device(device)
666
+
667
+ logger.info("LoRA weights loaded")
668
+
669
+ # save model here before casting to dit_weight_dtype
670
+ if args.save_merged_model:
671
+ logger.info(f"Saving merged model to {args.save_merged_model}")
672
+ mem_eff_save_file(transformer.state_dict(), args.save_merged_model) # save_file needs a lot of memory
673
+ logger.info("Merged model saved")
674
+ return
675
+
676
+ logger.info(f"Casting model to {dit_weight_dtype}")
677
+ transformer.to(dtype=dit_weight_dtype)
678
+
679
+ if args.fp8_fast:
680
+ logger.info("Enabling FP8 acceleration")
681
+ params_to_keep = {"norm", "bias", "time_in", "vector_in", "guidance_in", "txt_in", "img_in"}
682
+ for name, param in transformer.named_parameters():
683
+ dtype_to_use = dit_dtype if any(keyword in name for keyword in params_to_keep) else dit_weight_dtype
684
+ param.to(dtype=dtype_to_use)
685
+ convert_fp8_linear(transformer, dit_dtype, params_to_keep=params_to_keep)
686
+
687
+ if args.compile:
688
+ compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args
689
+ logger.info(
690
+ f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]"
691
+ )
692
+ torch._dynamo.config.cache_size_limit = 32
693
+ for i, block in enumerate(transformer.single_blocks):
694
+ compiled_block = torch.compile(
695
+ block,
696
+ backend=compile_backend,
697
+ mode=compile_mode,
698
+ dynamic=compile_dynamic.lower() == "true",
699
+ fullgraph=compile_fullgraph.lower() == "true",
700
+ )
701
+ transformer.single_blocks[i] = compiled_block
702
+ for i, block in enumerate(transformer.double_blocks):
703
+ compiled_block = torch.compile(
704
+ block,
705
+ backend=compile_backend,
706
+ mode=compile_mode,
707
+ dynamic=compile_dynamic.lower() == "true",
708
+ fullgraph=compile_fullgraph.lower() == "true",
709
+ )
710
+ transformer.double_blocks[i] = compiled_block
711
+
712
+ if blocks_to_swap > 0:
713
+ logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {device}")
714
+ transformer.enable_block_swap(blocks_to_swap, device, supports_backward=False)
715
+ transformer.move_to_device_except_swap_blocks(device)
716
+ transformer.prepare_block_swap_before_forward()
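+ # swapped blocks are kept on the CPU and moved to the GPU on demand during the forward pass, reducing VRAM usage at some speed cost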
717
+ else:
718
+ logger.info(f"Moving model to {device}")
719
+ transformer.to(device=device)
720
+ if args.img_in_txt_in_offloading:
721
+ logger.info("Enable offloading img_in and txt_in to CPU")
722
+ transformer.enable_img_in_txt_in_offloading()
723
+
724
+ # load scheduler
725
+ logger.info(f"Loading scheduler")
726
+ scheduler = FlowMatchDiscreteScheduler(shift=args.flow_shift, reverse=True, solver="euler")
727
+
728
+ # Prepare timesteps
729
+ num_inference_steps = args.infer_steps
730
+ scheduler.set_timesteps(num_inference_steps, device=device) # n_tokens is not used in FlowMatchDiscreteScheduler
731
+ timesteps = scheduler.timesteps
732
+
733
+ # Prepare generator
734
+ num_videos_per_prompt = 1 # args.num_videos # currently only 1 video per prompt is supported; this is the batch size
735
+ seed = args.seed
736
+ if seed is None:
737
+ seeds = [random.randint(0, 2**32 - 1) for _ in range(num_videos_per_prompt)]
738
+ elif isinstance(seed, int):
739
+ seeds = [seed + i for i in range(num_videos_per_prompt)]
740
+ else:
741
+ raise ValueError(f"Seed must be an integer or None, got {seed}.")
742
+ generator = [torch.Generator(device).manual_seed(seed) for seed in seeds]
743
+
744
+ # Prepare noisy latents
745
+ num_channels_latents = 16 # transformer.config.in_channels
746
+ vae_scale_factor = 2 ** (4 - 1) # len(self.vae.config.block_out_channels) == 4
747
+
748
+ vae_ver = vae.VAE_VER
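+ # the VAE version determines temporal compression: "884" maps every 4 frames (after the first) to one latent frame, "888" every 8, otherwise no temporal compression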
749
+ if "884" in vae_ver:
750
+ latent_video_length = (video_length - 1) // 4 + 1
751
+ elif "888" in vae_ver:
752
+ latent_video_length = (video_length - 1) // 8 + 1
753
+ else:
754
+ latent_video_length = video_length
755
+
756
+ # shape = (
757
+ # num_videos_per_prompt,
758
+ # num_channels_latents,
759
+ # latent_video_length,
760
+ # height // vae_scale_factor,
761
+ # width // vae_scale_factor,
762
+ # )
763
+ # latents = randn_tensor(shape, generator=generator, device=device, dtype=dit_dtype)
764
+
765
+ # generate noise frame by frame so the first N frames are identical when the same seed is given
766
+ shape_of_frame = (num_videos_per_prompt, num_channels_latents, 1, height // vae_scale_factor, width // vae_scale_factor)
767
+ latents = []
768
+ for i in range(latent_video_length):
769
+ latents.append(randn_tensor(shape_of_frame, generator=generator, device=device, dtype=dit_dtype))
770
+ latents = torch.cat(latents, dim=2)
771
+
772
+ # pad image_latents to match the length of video_latents
773
+ if image_latents is not None:
774
+ zero_latents = torch.zeros_like(latents)
775
+ zero_latents[:, :, :1, :, :] = image_latents
776
+ image_latents = zero_latents
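+ # only the first latent frame holds the encoded image; the remaining frames stay zero, and this tensor is concatenated channel-wise with the noisy latents at every denoising step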
777
+
778
+ if args.video_path is not None:
779
+ # v2v inference
780
+ noise = latents
781
+ assert noise.shape == video_latents.shape, f"noise shape {noise.shape} != video_latents shape {video_latents.shape}"
782
+
783
+ num_inference_steps = int(num_inference_steps * args.strength)
784
+ timestep_start = scheduler.timesteps[-num_inference_steps] # larger strength -> more denoising steps and an earlier (noisier) starting timestep
785
+ t = timestep_start / 1000.0
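+ # t is the start timestep normalized to [0, 1]; the initial latent is a linear mix of noise and the encoded video, so a higher strength starts from a noisier latent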
786
+ latents = noise * t + video_latents * (1 - t)
787
+
788
+ timesteps = timesteps[-num_inference_steps:]
789
+
790
+ logger.info(f"strength: {args.strength}, num_inference_steps: {num_inference_steps}, timestep_start: {timestep_start}")
791
+
792
+ # FlowMatchDiscreteScheduler does not have init_noise_sigma
793
+
794
+ # Denoising loop
795
+ embedded_guidance_scale = args.embedded_cfg_scale
796
+ if embedded_guidance_scale is not None:
797
+ guidance_expand = torch.tensor([embedded_guidance_scale * 1000.0] * latents.shape[0], dtype=torch.float32, device="cpu")
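+ # the embedded guidance value is scaled by 1000 (the same scale as the scheduler timesteps) and passed to the model as an extra conditioning input; it is duplicated below when CFG is enabled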
798
+ guidance_expand = guidance_expand.to(device=device, dtype=dit_dtype)
799
+ if do_classifier_free_guidance:
800
+ guidance_expand = torch.cat([guidance_expand, guidance_expand], dim=0)
801
+ else:
802
+ guidance_expand = None
803
+ freqs_cos, freqs_sin = get_rotary_pos_embed(vae_ver, transformer, video_length, height, width)
804
+ # n_tokens = freqs_cos.shape[0]
805
+
806
+ # move and cast all inputs to the correct device and dtype
807
+ prompt_embeds = prompt_embeds.to(device=device, dtype=dit_dtype)
808
+ prompt_mask = prompt_mask.to(device=device)
809
+ prompt_embeds_2 = prompt_embeds_2.to(device=device, dtype=dit_dtype)
810
+ prompt_mask_2 = prompt_mask_2.to(device=device)
811
+
812
+ freqs_cos = freqs_cos.to(device=device, dtype=dit_dtype)
813
+ freqs_sin = freqs_sin.to(device=device, dtype=dit_dtype)
814
+
815
+ num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order # this should be 0 in v2v inference
816
+
817
+ # assert split_uncond and split_attn
818
+ if args.split_attn and do_classifier_free_guidance and not args.split_uncond:
819
+ logger.warning("split_attn is enabled, split_uncond will be enabled as well.")
820
+ args.split_uncond = True
821
+
822
+ # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]) as p:
823
+ with tqdm(total=num_inference_steps) as progress_bar:
824
+ for i, t in enumerate(timesteps):
825
+ latents = scheduler.scale_model_input(latents, t)
826
+
827
+ # predict the noise residual
828
+ with torch.no_grad(), accelerator.autocast():
829
+ latents_input = latents if not do_classifier_free_guidance else torch.cat([latents, latents], dim=0)
830
+ if image_latents is not None:
831
+ latents_image_input = (
832
+ image_latents if not do_classifier_free_guidance else torch.cat([image_latents, image_latents], dim=0)
833
+ )
834
+ latents_input = torch.cat([latents_input, latents_image_input], dim=1) # 1 or 2, C*2, F, H, W
835
+
836
+ batch_size = 1 if args.split_uncond else latents_input.shape[0]
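+ # with --split_uncond the conditional and unconditional predictions are computed in separate forward passes of batch size 1, which is slower but uses less memory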
837
+
838
+ noise_pred_list = []
839
+ for j in range(0, latents_input.shape[0], batch_size):
840
+ noise_pred = transformer( # For an input image (129, 192, 336) (1, 256, 256)
841
+ latents_input[j : j + batch_size], # [1, 16, 33, 24, 42]
842
+ t.repeat(batch_size).to(device=device, dtype=dit_dtype), # [1]
843
+ text_states=prompt_embeds[j : j + batch_size], # [1, 256, 4096]
844
+ text_mask=prompt_mask[j : j + batch_size], # [1, 256]
845
+ text_states_2=prompt_embeds_2[j : j + batch_size], # [1, 768]
846
+ freqs_cos=freqs_cos, # [seqlen, head_dim]
847
+ freqs_sin=freqs_sin, # [seqlen, head_dim]
848
+ guidance=guidance_expand[j : j + batch_size], # [1]
849
+ return_dict=True,
850
+ )["x"]
851
+ noise_pred_list.append(noise_pred)
852
+ noise_pred = torch.cat(noise_pred_list, dim=0)
853
+
854
+ # perform classifier free guidance
855
+ if do_classifier_free_guidance:
856
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
857
+ noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
858
+
859
+ # # SkyReels' rescale noise config is omitted for now
860
+ # if guidance_rescale > 0.0:
861
+ # # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
862
+ # noise_pred = rescale_noise_cfg(
863
+ # noise_pred,
864
+ # noise_pred_cond,
865
+ # guidance_rescale=self.guidance_rescale,
866
+ # )
867
+
868
+ # compute the previous noisy sample x_t -> x_t-1
869
+ latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
870
+
871
+ # update progress bar
872
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
873
+ if progress_bar is not None:
874
+ progress_bar.update()
875
+
876
+ # print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))
877
+ # print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
878
+
879
+ latents = latents.detach().cpu()
880
+ transformer = None
881
+ clean_memory_on_device(device)
882
+
883
+ # Save samples
884
+ output_type = args.output_type
885
+ save_path = args.save_path # if args.save_path_suffix == "" else f"{args.save_path}_{args.save_path_suffix}"
886
+ os.makedirs(save_path, exist_ok=True)
887
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
888
+
889
+ if output_type == "latent" or output_type == "both":
890
+ # save latent
891
+ for i, latent in enumerate(latents):
892
+ latent_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}_latent.safetensors"
893
+
894
+ if args.no_metadata:
895
+ metadata = None
896
+ else:
897
+ metadata = {
898
+ "seeds": f"{seeds[i]}",
899
+ "prompt": f"{args.prompt}",
900
+ "height": f"{height}",
901
+ "width": f"{width}",
902
+ "video_length": f"{video_length}",
903
+ "infer_steps": f"{num_inference_steps}",
904
+ "guidance_scale": f"{args.guidance_scale}",
905
+ "embedded_cfg_scale": f"{args.embedded_cfg_scale}",
906
+ }
907
+ if args.negative_prompt is not None:
908
+ metadata["negative_prompt"] = f"{args.negative_prompt}"
909
+ sd = {"latent": latent}
910
+ save_file(sd, latent_path, metadata=metadata)
911
+
912
+ logger.info(f"Latent save to: {latent_path}")
913
+ if output_type == "video" or output_type == "both":
914
+ # save video
915
+ videos = decode_latents(args, latents, device)
916
+ for i, sample in enumerate(videos):
917
+ original_name = "" if original_base_names is None else f"_{original_base_names[i]}"
918
+ sample = sample.unsqueeze(0)
919
+ video_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}{original_name}.mp4"
920
+ save_videos_grid(sample, video_path, fps=args.fps)
921
+ logger.info(f"Sample save to: {video_path}")
922
+ elif output_type == "images":
923
+ # save images
924
+ videos = decode_latents(args, latents, device)
925
+ for i, sample in enumerate(videos):
926
+ original_name = "" if original_base_names is None else f"_{original_base_names[i]}"
927
+ sample = sample.unsqueeze(0)
928
+ image_name = f"{time_flag}_{i}_{seeds[i]}{original_name}"
929
+ save_images_grid(sample, save_path, image_name)
930
+ logger.info(f"Sample images save to: {save_path}/{image_name}")
931
+
932
+ logger.info("Done!")
933
+
934
+
935
+ if __name__ == "__main__":
936
+ main()
base_wan_generate_video.py ADDED
@@ -0,0 +1,1892 @@
1
+ import argparse
2
+ from datetime import datetime
3
+ import gc
4
+ import random
5
+ import os
6
+ import re
7
+ import time
8
+ import math
9
+ import copy
10
+ from types import ModuleType, SimpleNamespace
11
+ from typing import Tuple, Optional, List, Union, Any, Dict
12
+
13
+ import torch
14
+ import accelerate
15
+ from accelerate import Accelerator
16
+ from safetensors.torch import load_file, save_file
17
+ from safetensors import safe_open
18
+ from PIL import Image
19
+ import cv2
20
+ import numpy as np
21
+ import torchvision.transforms.functional as TF
22
+ from tqdm import tqdm
23
+
24
+ from networks import lora_wan
25
+ from utils.safetensors_utils import mem_eff_save_file, load_safetensors
26
+ from wan.configs import WAN_CONFIGS, SUPPORTED_SIZES
27
+ import wan
28
+ from wan.modules.model import WanModel, load_wan_model, detect_wan_sd_dtype
29
+ from wan.modules.vae import WanVAE
30
+ from wan.modules.t5 import T5EncoderModel
31
+ from wan.modules.clip import CLIPModel
32
+ from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
33
+ from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
34
+ from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
35
+
36
+ try:
37
+ from lycoris.kohya import create_network_from_weights
38
+ except:
39
+ pass
40
+
41
+ from utils.model_utils import str_to_dtype
42
+ from utils.device_utils import clean_memory_on_device
43
+ from hv_generate_video import save_images_grid, save_videos_grid, synchronize_device
44
+ from dataset.image_video_dataset import load_video
45
+
46
+ import logging
47
+
48
+ logger = logging.getLogger(__name__)
49
+ logging.basicConfig(level=logging.INFO)
50
+
51
+
52
+ class GenerationSettings:
53
+ def __init__(
54
+ self, device: torch.device, cfg, dit_dtype: torch.dtype, dit_weight_dtype: Optional[torch.dtype], vae_dtype: torch.dtype
55
+ ):
56
+ self.device = device
57
+ self.cfg = cfg
58
+ self.dit_dtype = dit_dtype
59
+ self.dit_weight_dtype = dit_weight_dtype
60
+ self.vae_dtype = vae_dtype
61
+
62
+
63
+ def parse_args() -> argparse.Namespace:
64
+ """parse command line arguments"""
65
+ parser = argparse.ArgumentParser(description="Wan 2.1 inference script")
66
+
67
+ # WAN arguments
68
+ parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory (Wan 2.1 official).")
69
+ parser.add_argument("--task", type=str, default="t2v-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run.")
70
+ parser.add_argument(
71
+ "--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++", "vanilla"], help="The solver used to sample."
72
+ )
73
+
74
+ parser.add_argument("--dit", type=str, default=None, help="DiT checkpoint path")
75
+ parser.add_argument("--vae", type=str, default=None, help="VAE checkpoint path")
76
+ parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is bfloat16")
77
+ parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU")
78
+ parser.add_argument("--t5", type=str, default=None, help="text encoder (T5) checkpoint path")
79
+ parser.add_argument("--clip", type=str, default=None, help="text encoder (CLIP) checkpoint path")
80
+ # LoRA
81
+ parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
82
+ parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
83
+ parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns")
84
+ parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns")
85
+ parser.add_argument(
86
+ "--save_merged_model",
87
+ type=str,
88
+ default=None,
89
+ help="Save merged model to path. If specified, no inference will be performed.",
90
+ )
91
+
92
+ # inference
93
+ parser.add_argument("--prompt", type=str, default=None, help="prompt for generation")
94
+ parser.add_argument(
95
+ "--negative_prompt",
96
+ type=str,
97
+ default=None,
98
+ help="negative prompt for generation, use default negative prompt if not specified",
99
+ )
100
+ parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size, height and width")
101
+ parser.add_argument("--video_length", type=int, default=None, help="video length, Default depends on task")
102
+ parser.add_argument("--fps", type=int, default=16, help="video fps, Default is 16")
103
+ parser.add_argument("--infer_steps", type=int, default=None, help="number of inference steps")
104
+ parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
105
+ parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
106
+ parser.add_argument(
107
+ "--cpu_noise", action="store_true", help="Use CPU to generate noise (compatible with ComfyUI). Default is False."
108
+ )
109
+ parser.add_argument(
110
+ "--guidance_scale",
111
+ type=float,
112
+ default=5.0,
113
+ help="Guidance scale for classifier free guidance. Default is 5.0.",
114
+ )
115
+ parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference")
116
+ parser.add_argument("--image_path", type=str, default=None, help="path to image for image2video inference")
117
+ parser.add_argument("--end_image_path", type=str, default=None, help="path to end image for image2video inference")
118
+ parser.add_argument(
119
+ "--control_path",
120
+ type=str,
121
+ default=None,
122
+ help="path to control video for inference with controlnet. video file or directory with images",
123
+ )
124
+ parser.add_argument("--trim_tail_frames", type=int, default=0, help="trim tail N frames from the video before saving")
125
+ parser.add_argument(
126
+ "--cfg_skip_mode",
127
+ type=str,
128
+ default="none",
129
+ choices=["early", "late", "middle", "early_late", "alternate", "none"],
130
+ help="CFG skip mode. each mode skips different parts of the CFG. "
131
+ " early: initial steps, late: later steps, middle: middle steps, early_late: both early and late, alternate: alternate, none: no skip (default)",
132
+ )
133
+ parser.add_argument(
134
+ "--cfg_apply_ratio",
135
+ type=float,
136
+ default=None,
137
+ help="The ratio of steps to apply CFG (0.0 to 1.0). Default is None (apply all steps).",
138
+ )
139
+ parser.add_argument(
140
+ "--slg_layers", type=str, default=None, help="Skip block (layer) indices for SLG (Skip Layer Guidance), comma separated"
141
+ )
142
+ parser.add_argument(
143
+ "--slg_scale",
144
+ type=float,
145
+ default=3.0,
146
+ help="scale for SLG classifier free guidance. Default is 3.0. Ignored if slg_mode is None or uncond",
147
+ )
148
+ parser.add_argument("--slg_start", type=float, default=0.0, help="start ratio for inference steps for SLG. Default is 0.0.")
149
+ parser.add_argument("--slg_end", type=float, default=0.3, help="end ratio for inference steps for SLG. Default is 0.3.")
150
+ parser.add_argument(
151
+ "--slg_mode",
152
+ type=str,
153
+ default=None,
154
+ choices=["original", "uncond"],
155
+ help="SLG mode. original: same as SD3, uncond: replace uncond pred with SLG pred",
156
+ )
157
+
158
+ # Flow Matching
159
+ parser.add_argument(
160
+ "--flow_shift",
161
+ type=float,
162
+ default=None,
163
+ help="Shift factor for flow matching schedulers. Default depends on task.",
164
+ )
165
+
166
+ parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
167
+ parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
168
+ parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled")
169
+ parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model")
170
+ parser.add_argument(
171
+ "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
172
+ )
173
+ parser.add_argument(
174
+ "--attn_mode",
175
+ type=str,
176
+ default="torch",
177
+ choices=["flash", "flash2", "flash3", "torch", "sageattn", "xformers", "sdpa"],
178
+ help="attention mode",
179
+ )
180
+ parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model")
181
+ parser.add_argument(
182
+ "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type"
183
+ )
184
+ parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
185
+ parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
186
+ parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
187
+ parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
188
+ parser.add_argument(
189
+ "--compile_args",
190
+ nargs=4,
191
+ metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"),
192
+ default=["inductor", "max-autotune-no-cudagraphs", "False", "False"],
193
+ help="Torch.compile settings",
194
+ )
195
+
196
+ # New arguments for batch and interactive modes
197
+ parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file")
198
+ parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console")
199
+
200
+ args = parser.parse_args()
201
+
202
+ # Validate arguments
203
+ if args.from_file and args.interactive:
204
+ raise ValueError("Cannot use both --from_file and --interactive at the same time")
205
+
206
+ if args.prompt is None and not args.from_file and not args.interactive and args.latent_path is None:
207
+ raise ValueError("Either --prompt, --from_file, --interactive, or --latent_path must be specified")
208
+
209
+ assert (args.latent_path is None or len(args.latent_path) == 0) or (
210
+ args.output_type == "images" or args.output_type == "video"
211
+ ), "latent_path is only supported for images or video output"
212
+
213
+ return args
214
+
215
+
216
+ def parse_prompt_line(line: str) -> Dict[str, Any]:
217
+ """Parse a prompt line into a dictionary of argument overrides
218
+
219
+ Args:
220
+ line: Prompt line with options
221
+
222
+ Returns:
223
+ Dict[str, Any]: Dictionary of argument overrides
224
+ """
225
+ # TODO common function with hv_train_network.line_to_prompt_dict
226
+ parts = line.split(" --")
227
+ prompt = parts[0].strip()
228
+
229
+ # Create dictionary of overrides
230
+ overrides = {"prompt": prompt}
231
+
232
+ for part in parts[1:]:
233
+ if not part.strip():
234
+ continue
235
+ option_parts = part.split(" ", 1)
236
+ option = option_parts[0].strip()
237
+ value = option_parts[1].strip() if len(option_parts) > 1 else ""
238
+
239
+ # Map options to argument names
240
+ if option == "w":
241
+ overrides["video_size_width"] = int(value)
242
+ elif option == "h":
243
+ overrides["video_size_height"] = int(value)
244
+ elif option == "f":
245
+ overrides["video_length"] = int(value)
246
+ elif option == "d":
247
+ overrides["seed"] = int(value)
248
+ elif option == "s":
249
+ overrides["infer_steps"] = int(value)
250
+ elif option == "g" or option == "l":
251
+ overrides["guidance_scale"] = float(value)
252
+ elif option == "fs":
253
+ overrides["flow_shift"] = float(value)
254
+ elif option == "i":
255
+ overrides["image_path"] = value
256
+ elif option == "cn":
257
+ overrides["control_path"] = value
258
+ elif option == "n":
259
+ overrides["negative_prompt"] = value
260
+
261
+ return overrides
262
+
263
+
264
+ def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace:
265
+ """Apply overrides to args
266
+
267
+ Args:
268
+ args: Original arguments
269
+ overrides: Dictionary of overrides
270
+
271
+ Returns:
272
+ argparse.Namespace: New arguments with overrides applied
273
+ """
274
+ args_copy = copy.deepcopy(args)
275
+
276
+ for key, value in overrides.items():
277
+ if key == "video_size_width":
278
+ args_copy.video_size[1] = value
279
+ elif key == "video_size_height":
280
+ args_copy.video_size[0] = value
281
+ else:
282
+ setattr(args_copy, key, value)
283
+
284
+ return args_copy
285
+
286
+
287
+ def get_task_defaults(task: str, size: Optional[Tuple[int, int]] = None) -> Tuple[int, float, int, bool]:
288
+ """Return default values for each task
289
+
290
+ Args:
291
+ task: task name (t2v, t2i, i2v etc.)
292
+ size: size of the video (height, width)
293
+
294
+ Returns:
295
+ Tuple[int, float, int, bool]: (infer_steps, flow_shift, video_length, needs_clip)
296
+ """
297
+ height, width = size if size else (0, 0)
298
+
299
+ if "t2i" in task:
300
+ return 50, 5.0, 1, False
301
+ elif "i2v" in task:
302
+ flow_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0
303
+ return 40, flow_shift, 81, True
304
+ else: # t2v or default
305
+ return 50, 5.0, 81, False
306
+
307
+
308
+ def setup_args(args: argparse.Namespace) -> argparse.Namespace:
309
+ """Validate and set default values for optional arguments
310
+
311
+ Args:
312
+ args: command line arguments
313
+
314
+ Returns:
315
+ argparse.Namespace: updated arguments
316
+ """
317
+ # Get default values for the task
318
+ infer_steps, flow_shift, video_length, _ = get_task_defaults(args.task, tuple(args.video_size))
319
+
320
+ # Apply default values to unset arguments
321
+ if args.infer_steps is None:
322
+ args.infer_steps = infer_steps
323
+ if args.flow_shift is None:
324
+ args.flow_shift = flow_shift
325
+ if args.video_length is None:
326
+ args.video_length = video_length
327
+
328
+ # Force video_length to 1 for t2i tasks
329
+ if "t2i" in args.task:
330
+ assert args.video_length == 1, f"video_length should be 1 for task {args.task}"
331
+
332
+ # parse slg_layers
333
+ if args.slg_layers is not None:
334
+ args.slg_layers = list(map(int, args.slg_layers.split(",")))
335
+
336
+ return args
337
+
338
+
339
+ def check_inputs(args: argparse.Namespace) -> Tuple[int, int, int]:
340
+ """Validate video size and length
341
+
342
+ Args:
343
+ args: command line arguments
344
+
345
+ Returns:
346
+ Tuple[int, int, int]: (height, width, video_length)
347
+ """
348
+ height = args.video_size[0]
349
+ width = args.video_size[1]
350
+ size = f"{width}*{height}"
351
+
352
+ if size not in SUPPORTED_SIZES[args.task]:
353
+ logger.warning(f"Size {size} is not supported for task {args.task}. Supported sizes are {SUPPORTED_SIZES[args.task]}.")
354
+
355
+ video_length = args.video_length
356
+
357
+ if height % 8 != 0 or width % 8 != 0:
358
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
359
+
360
+ return height, width, video_length
361
+
362
+
363
+ def calculate_dimensions(video_size: Tuple[int, int], video_length: int, config) -> Tuple[Tuple[int, int, int, int], int]:
364
+ """calculate dimensions for the generation
365
+
366
+ Args:
367
+ video_size: video frame size (height, width)
368
+ video_length: number of frames in the video
369
+ config: model configuration
370
+
371
+ Returns:
372
+ Tuple[Tuple[int, int, int, int], int]:
373
+ ((channels, frames, height, width), seq_len)
374
+ """
375
+ height, width = video_size
376
+ frames = video_length
377
+
378
+ # calculate latent space dimensions
379
+ lat_f = (frames - 1) // config.vae_stride[0] + 1
380
+ lat_h = height // config.vae_stride[1]
381
+ lat_w = width // config.vae_stride[2]
382
+
383
+ # calculate sequence length
384
+ seq_len = math.ceil((lat_h * lat_w) / (config.patch_size[1] * config.patch_size[2]) * lat_f)
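+ # seq_len is the number of spatio-temporal patch tokens the DiT will process: spatial patches per latent frame times the number of latent frames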
385
+
386
+ return ((16, lat_f, lat_h, lat_w), seq_len)
387
+
388
+
389
+ def load_vae(args: argparse.Namespace, config, device: torch.device, dtype: torch.dtype) -> WanVAE:
390
+ """load VAE model
391
+
392
+ Args:
393
+ args: command line arguments
394
+ config: model configuration
395
+ device: device to use
396
+ dtype: data type for the model
397
+
398
+ Returns:
399
+ WanVAE: loaded VAE model
400
+ """
401
+ vae_path = args.vae if args.vae is not None else os.path.join(args.ckpt_dir, config.vae_checkpoint)
402
+
403
+ logger.info(f"Loading VAE model from {vae_path}")
404
+ cache_device = torch.device("cpu") if args.vae_cache_cpu else None
405
+ vae = WanVAE(vae_path=vae_path, device=device, dtype=dtype, cache_device=cache_device)
406
+ return vae
407
+
408
+
409
+ def load_text_encoder(args: argparse.Namespace, config, device: torch.device) -> T5EncoderModel:
410
+ """load text encoder (T5) model
411
+
412
+ Args:
413
+ args: command line arguments
414
+ config: model configuration
415
+ device: device to use
416
+
417
+ Returns:
418
+ T5EncoderModel: loaded text encoder model
419
+ """
420
+ checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_checkpoint)
421
+ tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_tokenizer)
422
+
423
+ text_encoder = T5EncoderModel(
424
+ text_len=config.text_len,
425
+ dtype=config.t5_dtype,
426
+ device=device,
427
+ checkpoint_path=checkpoint_path,
428
+ tokenizer_path=tokenizer_path,
429
+ weight_path=args.t5,
430
+ fp8=args.fp8_t5,
431
+ )
432
+
433
+ return text_encoder
434
+
435
+
436
+ def load_clip_model(args: argparse.Namespace, config, device: torch.device) -> CLIPModel:
437
+ """load CLIP model (for I2V only)
438
+
439
+ Args:
440
+ args: command line arguments
441
+ config: model configuration
442
+ device: device to use
443
+
444
+ Returns:
445
+ CLIPModel: loaded CLIP model
446
+ """
447
+ checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_checkpoint)
448
+ tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_tokenizer)
449
+
450
+ clip = CLIPModel(
451
+ dtype=config.clip_dtype,
452
+ device=device,
453
+ checkpoint_path=checkpoint_path,
454
+ tokenizer_path=tokenizer_path,
455
+ weight_path=args.clip,
456
+ )
457
+
458
+ return clip
459
+
460
+
461
+ def load_dit_model(
462
+ args: argparse.Namespace,
463
+ config,
464
+ device: torch.device,
465
+ dit_dtype: torch.dtype,
466
+ dit_weight_dtype: Optional[torch.dtype] = None,
467
+ is_i2v: bool = False,
468
+ ) -> WanModel:
469
+ """load DiT model
470
+
471
+ Args:
472
+ args: command line arguments
473
+ config: model configuration
474
+ device: device to use
475
+ dit_dtype: data type for the model
476
+ dit_weight_dtype: data type for the model weights. None for as-is
477
+ is_i2v: I2V mode
478
+
479
+ Returns:
480
+ WanModel: loaded DiT model
481
+ """
482
+ loading_device = "cpu"
483
+ if args.blocks_to_swap == 0 and args.lora_weight is None and not args.fp8_scaled:
484
+ loading_device = device
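+ # load weights directly to the GPU only when no block swap, no LoRA merge and no fp8 scaling is needed; otherwise load to CPU first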
485
+
486
+ loading_weight_dtype = dit_weight_dtype
487
+ if args.fp8_scaled or args.lora_weight is not None:
488
+ loading_weight_dtype = dit_dtype # load as-is
489
+
490
+ # do not fp8 optimize because we will merge LoRA weights
491
+ model = load_wan_model(config, device, args.dit, args.attn_mode, False, loading_device, loading_weight_dtype, False)
492
+
493
+ return model
494
+
495
+
496
+ def merge_lora_weights(lora_module: ModuleType, model: torch.nn.Module, args: argparse.Namespace, device: torch.device) -> None:
497
+ """merge LoRA weights to the model
498
+
499
+ Args:
500
+ model: DiT model
501
+ args: command line arguments
502
+ device: device to use
503
+ """
504
+ if args.lora_weight is None or len(args.lora_weight) == 0:
505
+ return
506
+
507
+ for i, lora_weight in enumerate(args.lora_weight):
508
+ if args.lora_multiplier is not None and len(args.lora_multiplier) > i:
509
+ lora_multiplier = args.lora_multiplier[i]
510
+ else:
511
+ lora_multiplier = 1.0
512
+
513
+ logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}")
514
+ weights_sd = load_file(lora_weight)
515
+
516
+ # apply include/exclude patterns
517
+ original_key_count = len(weights_sd.keys())
518
+ if args.include_patterns is not None and len(args.include_patterns) > i:
519
+ include_pattern = args.include_patterns[i]
520
+ regex_include = re.compile(include_pattern)
521
+ weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)}
522
+ logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}")
523
+ if args.exclude_patterns is not None and len(args.exclude_patterns) > i:
524
+ original_key_count_ex = len(weights_sd.keys())
525
+ exclude_pattern = args.exclude_patterns[i]
526
+ regex_exclude = re.compile(exclude_pattern)
527
+ weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)}
528
+ logger.info(
529
+ f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}"
530
+ )
531
+ if len(weights_sd) != original_key_count:
532
+ remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()]))
533
+ remaining_keys.sort()
534
+ logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}")
535
+ if len(weights_sd) == 0:
536
+ logger.warning(f"No keys left after filtering.")
537
+
538
+ if args.lycoris:
539
+ lycoris_net, _ = create_network_from_weights(
540
+ multiplier=lora_multiplier,
541
+ file=None,
542
+ weights_sd=weights_sd,
543
+ unet=model,
544
+ text_encoder=None,
545
+ vae=None,
546
+ for_inference=True,
547
+ )
548
+ lycoris_net.merge_to(None, model, weights_sd, dtype=None, device=device)
549
+ else:
550
+ network = lora_module.create_arch_network_from_weights(lora_multiplier, weights_sd, unet=model, for_inference=True)
551
+ network.merge_to(None, model, weights_sd, device=device, non_blocking=True)
552
+
553
+ synchronize_device(device)
554
+ logger.info("LoRA weights loaded")
555
+
556
+ # save model here before casting to dit_weight_dtype
557
+ if args.save_merged_model:
558
+ logger.info(f"Saving merged model to {args.save_merged_model}")
559
+ mem_eff_save_file(model.state_dict(), args.save_merged_model) # save_file needs a lot of memory
560
+ logger.info("Merged model saved")
561
+
562
+
563
+ def optimize_model(
564
+ model: WanModel, args: argparse.Namespace, device: torch.device, dit_dtype: torch.dtype, dit_weight_dtype: torch.dtype
565
+ ) -> None:
566
+ """optimize the model (FP8 conversion, device move etc.)
567
+
568
+ Args:
569
+ model: dit model
570
+ args: command line arguments
571
+ device: device to use
572
+ dit_dtype: dtype for the model
573
+ dit_weight_dtype: dtype for the model weights
574
+ """
575
+ if args.fp8_scaled:
576
+ # load state dict as-is and optimize to fp8
577
+ state_dict = model.state_dict()
578
+
579
+ # if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy)
580
+ move_to_device = args.blocks_to_swap == 0 # if blocks_to_swap > 0, we will keep the model on CPU
581
+ state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=args.fp8_fast)
582
+
583
+ info = model.load_state_dict(state_dict, strict=True, assign=True)
584
+ logger.info(f"Loaded FP8 optimized weights: {info}")
585
+
586
+ if args.blocks_to_swap == 0:
587
+ model.to(device) # make sure all parameters are on the right device (e.g. RoPE etc.)
588
+ else:
589
+ # simple cast to dit_dtype
590
+ target_dtype = None # load as-is (dit_weight_dtype == dtype of the weights in state_dict)
591
+ target_device = None
592
+
593
+ if dit_weight_dtype is not None: # in case of args.fp8 and not args.fp8_scaled
594
+ logger.info(f"Convert model to {dit_weight_dtype}")
595
+ target_dtype = dit_weight_dtype
596
+
597
+ if args.blocks_to_swap == 0:
598
+ logger.info(f"Move model to device: {device}")
599
+ target_device = device
600
+
601
+ model.to(target_device, target_dtype) # move and cast at the same time. this reduces redundant copy operations
602
+
603
+ if args.compile:
604
+ compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args
605
+ logger.info(
606
+ f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]"
607
+ )
608
+ torch._dynamo.config.cache_size_limit = 32
609
+ for i in range(len(model.blocks)):
610
+ model.blocks[i] = torch.compile(
611
+ model.blocks[i],
612
+ backend=compile_backend,
613
+ mode=compile_mode,
614
+ dynamic=compile_dynamic.lower() in "true",
615
+ fullgraph=compile_fullgraph.lower() in "true",
616
+ )
617
+
618
+ if args.blocks_to_swap > 0:
619
+ logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}")
620
+ model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False)
621
+ model.move_to_device_except_swap_blocks(device)
622
+ model.prepare_block_swap_before_forward()
623
+ else:
624
+ # make sure the model is on the right device
625
+ model.to(device)
626
+
627
+ model.eval().requires_grad_(False)
628
+ clean_memory_on_device(device)
629
+
630
+
631
+ def prepare_t2v_inputs(
632
+ args: argparse.Namespace,
633
+ config,
634
+ accelerator: Accelerator,
635
+ device: torch.device,
636
+ vae: Optional[WanVAE] = None,
637
+ encoded_context: Optional[Dict] = None,
638
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
639
+ """Prepare inputs for T2V
640
+
641
+ Args:
642
+ args: command line arguments
643
+ config: model configuration
644
+ accelerator: Accelerator instance
645
+ device: device to use
646
+ vae: VAE model for control video encoding
647
+ encoded_context: Pre-encoded text context
648
+
649
+ Returns:
650
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
651
+ (noise, context, context_null, (arg_c, arg_null))
652
+ """
653
+ # Prepare inputs for T2V
654
+ # calculate dimensions and sequence length
655
+ height, width = args.video_size
656
+ frames = args.video_length
657
+ (_, lat_f, lat_h, lat_w), seq_len = calculate_dimensions(args.video_size, args.video_length, config)
658
+ target_shape = (16, lat_f, lat_h, lat_w)
659
+
660
+ # configure negative prompt
661
+ n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt
662
+
663
+ # set seed
664
+ seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
665
+ if not args.cpu_noise:
666
+ seed_g = torch.Generator(device=device)
667
+ seed_g.manual_seed(seed)
668
+ else:
669
+ # ComfyUI compatible noise
670
+ seed_g = torch.manual_seed(seed)
671
+
672
+ if encoded_context is None:
673
+ # load text encoder
674
+ text_encoder = load_text_encoder(args, config, device)
675
+ text_encoder.model.to(device)
676
+
677
+ # encode prompt
678
+ with torch.no_grad():
679
+ if args.fp8_t5:
680
+ with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype):
681
+ context = text_encoder([args.prompt], device)
682
+ context_null = text_encoder([n_prompt], device)
683
+ else:
684
+ context = text_encoder([args.prompt], device)
685
+ context_null = text_encoder([n_prompt], device)
686
+
687
+ # free text encoder and clean memory
688
+ del text_encoder
689
+ clean_memory_on_device(device)
690
+ else:
691
+ # Use pre-encoded context
692
+ context = encoded_context["context"]
693
+ context_null = encoded_context["context_null"]
694
+
695
+ # Fun-Control: encode control video to latent space
696
+ if config.is_fun_control:
697
+ # TODO use same resizing as for image
698
+ logger.info(f"Encoding control video to latent space")
699
+ # C, F, H, W
700
+ control_video = load_control_video(args.control_path, frames, height, width).to(device)
701
+ vae.to_device(device)
702
+ with torch.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
703
+ control_latent = vae.encode([control_video])[0]
704
+ y = torch.concat([control_latent, torch.zeros_like(control_latent)], dim=0) # add control video latent
705
+ vae.to_device("cpu")
706
+ else:
707
+ y = None
708
+
709
+ # generate noise
710
+ noise = torch.randn(target_shape, dtype=torch.float32, generator=seed_g, device=device if not args.cpu_noise else "cpu")
711
+ noise = noise.to(device)
712
+
713
+ # prepare model input arguments
714
+ arg_c = {"context": context, "seq_len": seq_len}
715
+ arg_null = {"context": context_null, "seq_len": seq_len}
716
+ if y is not None:
717
+ arg_c["y"] = [y]
718
+ arg_null["y"] = [y]
719
+
720
+ return noise, context, context_null, (arg_c, arg_null)
721
+
722
+
723
+ def prepare_i2v_inputs(
724
+ args: argparse.Namespace,
725
+ config,
726
+ accelerator: Accelerator,
727
+ device: torch.device,
728
+ vae: WanVAE,
729
+ encoded_context: Optional[Dict] = None,
730
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
731
+ """Prepare inputs for I2V
732
+
733
+ Args:
734
+ args: command line arguments
735
+ config: model configuration
736
+ accelerator: Accelerator instance
737
+ device: device to use
738
+ vae: VAE model, used for image encoding
739
+ encoded_context: Pre-encoded text context
740
+
741
+ Returns:
742
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
743
+ (noise, context, context_null, y, (arg_c, arg_null))
744
+ """
745
+ # get video dimensions
746
+ height, width = args.video_size
747
+ frames = args.video_length
748
+ max_area = width * height
749
+
750
+ # load image
751
+ img = Image.open(args.image_path).convert("RGB")
752
+
753
+ # convert to numpy
754
+ img_cv2 = np.array(img) # PIL to numpy
755
+
756
+ # convert to tensor (-1 to 1)
757
+ img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)
758
+
759
+ # end frame image
760
+ if args.end_image_path is not None:
761
+ end_img = Image.open(args.end_image_path).convert("RGB")
762
+ end_img_cv2 = np.array(end_img) # PIL to numpy
763
+ else:
764
+ end_img = None
765
+ end_img_cv2 = None
766
+ has_end_image = end_img is not None
767
+
768
+ # calculate latent dimensions: keep aspect ratio
769
+ height, width = img_tensor.shape[1:]
770
+ aspect_ratio = height / width
771
+ lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
772
+ lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
773
+ height = lat_h * config.vae_stride[1]
774
+ width = lat_w * config.vae_stride[2]
775
+ lat_f = (frames - 1) // config.vae_stride[0] + 1 # size of latent frames
776
+ max_seq_len = (lat_f + (1 if has_end_image else 0)) * lat_h * lat_w // (config.patch_size[1] * config.patch_size[2])
777
+
778
+ # set seed
779
+ seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
780
+ if not args.cpu_noise:
781
+ seed_g = torch.Generator(device=device)
782
+ seed_g.manual_seed(seed)
783
+ else:
784
+ # ComfyUI compatible noise
785
+ seed_g = torch.manual_seed(seed)
786
+
787
+ # generate noise
788
+ noise = torch.randn(
789
+ 16,
790
+ lat_f + (1 if has_end_image else 0),
791
+ lat_h,
792
+ lat_w,
793
+ dtype=torch.float32,
794
+ generator=seed_g,
795
+ device=device if not args.cpu_noise else "cpu",
796
+ )
797
+ noise = noise.to(device)
798
+
799
+ # configure negative prompt
800
+ n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt
801
+
802
+ if encoded_context is None:
803
+ # load text encoder
804
+ text_encoder = load_text_encoder(args, config, device)
805
+ text_encoder.model.to(device)
806
+
807
+ # encode prompt
808
+ with torch.no_grad():
809
+ if args.fp8_t5:
810
+ with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype):
811
+ context = text_encoder([args.prompt], device)
812
+ context_null = text_encoder([n_prompt], device)
813
+ else:
814
+ context = text_encoder([args.prompt], device)
815
+ context_null = text_encoder([n_prompt], device)
816
+
817
+ # free text encoder and clean memory
818
+ del text_encoder
819
+ clean_memory_on_device(device)
820
+
821
+ # load CLIP model
822
+ clip = load_clip_model(args, config, device)
823
+ clip.model.to(device)
824
+
825
+ # encode image to CLIP context
826
+ logger.info(f"Encoding image to CLIP context")
827
+ with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
828
+ clip_context = clip.visual([img_tensor[:, None, :, :]])
829
+ logger.info(f"Encoding complete")
830
+
831
+ # free CLIP model and clean memory
832
+ del clip
833
+ clean_memory_on_device(device)
834
+ else:
835
+ # Use pre-encoded context
836
+ context = encoded_context["context"]
837
+ context_null = encoded_context["context_null"]
838
+ clip_context = encoded_context["clip_context"]
839
+
840
+ # encode image to latent space with VAE
841
+ logger.info(f"Encoding image to latent space")
842
+ vae.to_device(device)
843
+
844
+ # resize image
845
+ interpolation = cv2.INTER_AREA if height < img_cv2.shape[0] else cv2.INTER_CUBIC
846
+ img_resized = cv2.resize(img_cv2, (width, height), interpolation=interpolation)
847
+ img_resized = TF.to_tensor(img_resized).sub_(0.5).div_(0.5).to(device) # -1 to 1, CHW
848
+ img_resized = img_resized.unsqueeze(1) # CFHW
849
+
850
+ if has_end_image:
851
+ interpolation = cv2.INTER_AREA if height < end_img_cv2.shape[0] else cv2.INTER_CUBIC  # compare target height to source height (shape[0]), matching the first-image resize above
852
+ end_img_resized = cv2.resize(end_img_cv2, (width, height), interpolation=interpolation)
853
+ end_img_resized = TF.to_tensor(end_img_resized).sub_(0.5).div_(0.5).to(device) # -1 to 1, CHW
854
+ end_img_resized = end_img_resized.unsqueeze(1) # CFHW
855
+
856
+ # create mask for the first frame
857
+ msk = torch.zeros(4, lat_f + (1 if has_end_image else 0), lat_h, lat_w, device=device)
858
+ msk[:, 0] = 1
859
+ if has_end_image:
860
+ msk[:, -1] = 1
861
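+ # The mask marks which latent frames carry real image content: the first frame, plus the last
+ # frame when an end image is given. Concatenated with the 16-channel VAE latent below, it forms
+ # the conditioning tensor `y` passed to the I2V model.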
+
862
+ # encode image to latent space
863
+ with accelerator.autocast(), torch.no_grad():
864
+ # padding to match the required number of frames
865
+ padding_frames = frames - 1 # the first frame is image
866
+ img_resized = torch.concat([img_resized, torch.zeros(3, padding_frames, height, width, device=device)], dim=1)
867
+ y = vae.encode([img_resized])[0]
868
+
869
+ if has_end_image:
870
+ y_end = vae.encode([end_img_resized])[0]
871
+ y = torch.concat([y, y_end], dim=1) # add end frame
872
+
873
+ y = torch.concat([msk, y])
874
+ logger.info(f"Encoding complete")
875
+
876
+ # Fun-Control: encode control video to latent space
877
+ if config.is_fun_control:
878
+ # TODO use same resizing as for image
879
+ logger.info(f"Encoding control video to latent space")
880
+ # C, F, H, W
881
+ control_video = load_control_video(args.control_path, frames + (1 if has_end_image else 0), height, width).to(device)
882
+ with accelerator.autocast(), torch.no_grad():
883
+ control_latent = vae.encode([control_video])[0]
884
+ y = y[msk.shape[0] :] # remove mask because Fun-Control does not need it
885
+ if has_end_image:
886
+ y[:, 1:-1] = 0 # zero out image latents except the first and last frame; according to WanVideoWrapper, this doesn't work reliably
887
+ else:
888
+ y[:, 1:] = 0 # remove image latent except first frame
889
+ y = torch.concat([control_latent, y], dim=0) # add control video latent
890
+
891
+ # prepare model input arguments
892
+ arg_c = {
893
+ "context": [context[0]],
894
+ "clip_fea": clip_context,
895
+ "seq_len": max_seq_len,
896
+ "y": [y],
897
+ }
898
+
899
+ arg_null = {
900
+ "context": context_null,
901
+ "clip_fea": clip_context,
902
+ "seq_len": max_seq_len,
903
+ "y": [y],
904
+ }
905
+
906
+ vae.to_device("cpu") # move VAE to CPU to save memory
907
+ clean_memory_on_device(device)
908
+
909
+ return noise, context, context_null, y, (arg_c, arg_null)
910
+
911
+
912
+ def load_control_video(control_path: str, frames: int, height: int, width: int) -> torch.Tensor:
913
+ """load control video to latent space
914
+
915
+ Args:
916
+ control_path: path to control video
917
+ frames: number of frames in the video
918
+ height: height of the video
919
+ width: width of the video
920
+
921
+ Returns:
922
+ torch.Tensor: control video latent, CFHW
923
+ """
924
+ logger.info(f"Load control video from {control_path}")
925
+ video = load_video(control_path, 0, frames, bucket_reso=(width, height)) # list of frames
926
+ if len(video) < frames:
927
+ raise ValueError(f"Video length is less than {frames}")
928
+ # video = np.stack(video, axis=0) # F, H, W, C
929
+ video = torch.stack([TF.to_tensor(frame).sub_(0.5).div_(0.5) for frame in video], dim=0) # F, C, H, W, -1 to 1
930
+ video = video.permute(1, 0, 2, 3) # C, F, H, W
931
+ return video
932
+
933
+
934
+ def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]:
935
+ """setup scheduler for sampling
936
+
937
+ Args:
938
+ args: command line arguments
939
+ config: model configuration
940
+ device: device to use
941
+
942
+ Returns:
943
+ Tuple[Any, torch.Tensor]: (scheduler, timesteps)
944
+ """
945
+ if args.sample_solver == "unipc":
946
+ scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False)
947
+ scheduler.set_timesteps(args.infer_steps, device=device, shift=args.flow_shift)
948
+ timesteps = scheduler.timesteps
949
+ elif args.sample_solver == "dpm++":
950
+ scheduler = FlowDPMSolverMultistepScheduler(
951
+ num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False
952
+ )
953
+ sampling_sigmas = get_sampling_sigmas(args.infer_steps, args.flow_shift)
954
+ timesteps, _ = retrieve_timesteps(scheduler, device=device, sigmas=sampling_sigmas)
955
+ elif args.sample_solver == "vanilla":
956
+ scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=config.num_train_timesteps, shift=args.flow_shift)
957
+ scheduler.set_timesteps(args.infer_steps, device=device)
958
+ timesteps = scheduler.timesteps
959
+
960
+ # FlowMatchDiscreteScheduler does not support generator argument in step method
961
+ org_step = scheduler.step
962
+
963
+ def step_wrapper(
964
+ model_output: torch.Tensor,
965
+ timestep: Union[int, torch.Tensor],
966
+ sample: torch.Tensor,
967
+ return_dict: bool = True,
968
+ generator=None,
969
+ ):
970
+ return org_step(model_output, timestep, sample, return_dict=return_dict)
971
+
972
+ scheduler.step = step_wrapper
973
+ else:
974
+ raise NotImplementedError("Unsupported solver.")
975
+
976
+ return scheduler, timesteps
977
+
978
+
979
+ def run_sampling(
980
+ model: WanModel,
981
+ noise: torch.Tensor,
982
+ scheduler: Any,
983
+ timesteps: torch.Tensor,
984
+ args: argparse.Namespace,
985
+ inputs: Tuple[dict, dict],
986
+ device: torch.device,
987
+ seed_g: torch.Generator,
988
+ accelerator: Accelerator,
989
+ is_i2v: bool = False,
990
+ use_cpu_offload: bool = True,
991
+ ) -> torch.Tensor:
992
+ """run sampling
993
+ Args:
994
+ model: dit model
995
+ noise: initial noise
996
+ scheduler: scheduler for sampling
997
+ timesteps: time steps for sampling
998
+ args: command line arguments
999
+ inputs: model input (arg_c, arg_null)
1000
+ device: device to use
1001
+ seed_g: random generator
1002
+ accelerator: Accelerator instance
1003
+ is_i2v: I2V mode (False means T2V mode)
1004
+ use_cpu_offload: Whether to offload tensors to CPU during processing
1005
+ Returns:
1006
+ torch.Tensor: generated latent
1007
+ """
1008
+ arg_c, arg_null = inputs
1009
+
1010
+ latent = noise
1011
+ latent_storage_device = device if not use_cpu_offload else "cpu"
1012
+ latent = latent.to(latent_storage_device)
1013
+
1014
+ # cfg skip
1015
+ apply_cfg_array = []
1016
+ num_timesteps = len(timesteps)
1017
+
1018
+ if args.cfg_skip_mode != "none" and args.cfg_apply_ratio is not None:
1019
+ # Calculate thresholds based on cfg_apply_ratio
1020
+ apply_steps = int(num_timesteps * args.cfg_apply_ratio)
1021
+
1022
+ if args.cfg_skip_mode == "early":
1023
+ # Skip CFG in early steps, apply in late steps
1024
+ start_index = num_timesteps - apply_steps
1025
+ end_index = num_timesteps
1026
+ elif args.cfg_skip_mode == "late":
1027
+ # Skip CFG in late steps, apply in early steps
1028
+ start_index = 0
1029
+ end_index = apply_steps
1030
+ elif args.cfg_skip_mode == "early_late":
1031
+ # Skip CFG in early and late steps, apply in middle steps
1032
+ start_index = (num_timesteps - apply_steps) // 2
1033
+ end_index = start_index + apply_steps
1034
+ elif args.cfg_skip_mode == "middle":
1035
+ # Skip CFG in middle steps, apply in early and late steps
1036
+ skip_steps = num_timesteps - apply_steps
1037
+ middle_start = (num_timesteps - skip_steps) // 2
1038
+ middle_end = middle_start + skip_steps
1039
+
1040
+ w = 0.0
1041
+ for step_idx in range(num_timesteps):
1042
+ if args.cfg_skip_mode == "alternate":
1043
+ # accumulate w and apply CFG when w >= 1.0
1044
+ w += args.cfg_apply_ratio
1045
+ apply = w >= 1.0
1046
+ if apply:
1047
+ w -= 1.0
1048
+ elif args.cfg_skip_mode == "middle":
1049
+ # Skip CFG in early and late steps, apply in middle steps
1050
+ apply = step_idx < middle_start or step_idx >= middle_end
1051
+ else:
1052
+ # Apply CFG on some steps based on ratio
1053
+ apply = step_idx >= start_index and step_idx < end_index
1054
+
1055
+ apply_cfg_array.append(apply)
1056
+
1057
+ pattern = ["A" if apply else "S" for apply in apply_cfg_array]
1058
+ pattern = "".join(pattern)
1059
+ logger.info(f"CFG skip mode: {args.cfg_skip_mode}, apply ratio: {args.cfg_apply_ratio}, pattern: {pattern}")
1060
+ else:
1061
+ # Apply CFG on all steps
1062
+ apply_cfg_array = [True] * num_timesteps
1063
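+ # Example of the resulting patterns (A = apply CFG, S = skip): with num_timesteps=10 and
+ # cfg_apply_ratio=0.5, "early" gives SSSSSAAAAA, "late" gives AAAAASSSSS,
+ # "early_late" gives SSAAAAASSS, "middle" gives AASSSSSAAA, and "alternate" gives SASASASASA.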
+
1064
+ # SLG original implementation is based on https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py
1065
+ slg_start_step = int(args.slg_start * num_timesteps)
1066
+ slg_end_step = int(args.slg_end * num_timesteps)
1067
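+ # In short: "original" mode adds a skip-layer term on top of regular CFG,
+ #   noise_pred = uncond + guidance_scale * (cond - uncond) + slg_scale * (cond - skip_layer_out),
+ # while "uncond" mode replaces the unconditional pass itself with the skip-layer pass.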
+
1068
+ for i, t in enumerate(tqdm(timesteps)):
1069
+ # latent is on CPU if use_cpu_offload is True
1070
+ latent_model_input = [latent.to(device)]
1071
+ timestep = torch.stack([t]).to(device)
1072
+
1073
+ with accelerator.autocast(), torch.no_grad():
1074
+ noise_pred_cond = model(latent_model_input, t=timestep, **arg_c)[0].to(latent_storage_device)
1075
+
1076
+ apply_cfg = apply_cfg_array[i] # apply CFG or not
1077
+ if apply_cfg:
1078
+ apply_slg = i >= slg_start_step and i < slg_end_step
1079
+ # print(f"Applying SLG: {apply_slg}, i: {i}, slg_start_step: {slg_start_step}, slg_end_step: {slg_end_step}")
1080
+ if args.slg_mode == "original" and apply_slg:
1081
+ noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to(latent_storage_device)
1082
+
1083
+ # apply guidance
1084
+ # SD3 formula: scaled = neg_out + (pos_out - neg_out) * cond_scale
1085
+ noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
1086
+
1087
+ # calculate skip layer out
1088
+ skip_layer_out = model(latent_model_input, t=timestep, skip_block_indices=args.slg_layers, **arg_null)[0].to(
1089
+ latent_storage_device
1090
+ )
1091
+
1092
+ # apply skip layer guidance
1093
+ # SD3 formula: scaled = scaled + (pos_out - skip_layer_out) * self.slg
1094
+ noise_pred = noise_pred + args.slg_scale * (noise_pred_cond - skip_layer_out)
1095
+ elif args.slg_mode == "uncond" and apply_slg:
1096
+ # noise_pred_uncond is skip layer out
1097
+ noise_pred_uncond = model(latent_model_input, t=timestep, skip_block_indices=args.slg_layers, **arg_null)[0].to(
1098
+ latent_storage_device
1099
+ )
1100
+
1101
+ # apply guidance
1102
+ noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
1103
+
1104
+ else:
1105
+ # normal guidance
1106
+ noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to(latent_storage_device)
1107
+
1108
+ # apply guidance
1109
+ noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
1110
+ else:
1111
+ noise_pred = noise_pred_cond
1112
+
1113
+ # step
1114
+ latent_input = latent.unsqueeze(0)
1115
+ temp_x0 = scheduler.step(noise_pred.unsqueeze(0), t, latent_input, return_dict=False, generator=seed_g)[0]
1116
+
1117
+ # update latent
1118
+ latent = temp_x0.squeeze(0)
1119
+
1120
+ return latent
1121
+
1122
+
1123
+ def generate(args: argparse.Namespace, gen_settings: GenerationSettings, shared_models: Optional[Dict] = None) -> torch.Tensor:
1124
+ """main function for generation
1125
+
1126
+ Args:
1127
+ args: command line arguments
1128
+ shared_models: dictionary containing pre-loaded models and encoded data
1129
+
1130
+ Returns:
1131
+ torch.Tensor: generated latent
1132
+ """
1133
+ device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = (
1134
+ gen_settings.device,
1135
+ gen_settings.cfg,
1136
+ gen_settings.dit_dtype,
1137
+ gen_settings.dit_weight_dtype,
1138
+ gen_settings.vae_dtype,
1139
+ )
1140
+
1141
+ # prepare accelerator
1142
+ mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16"
1143
+ accelerator = accelerate.Accelerator(mixed_precision=mixed_precision)
1144
+
1145
+ # I2V or T2V
1146
+ is_i2v = "i2v" in args.task
1147
+
1148
+ # prepare seed
1149
+ seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
1150
+ args.seed = seed # set seed to args for saving
1151
+
1152
+ # Check if we have shared models
1153
+ if shared_models is not None:
1154
+ # Use shared models and encoded data
1155
+ vae = shared_models.get("vae")
1156
+ model = shared_models.get("model")
1157
+ encoded_context = shared_models.get("encoded_contexts", {}).get(args.prompt)
1158
+
1159
+ # prepare inputs
1160
+ if is_i2v:
1161
+ # I2V
1162
+ noise, context, context_null, y, inputs = prepare_i2v_inputs(args, cfg, accelerator, device, vae, encoded_context)
1163
+ else:
1164
+ # T2V
1165
+ noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, vae, encoded_context)
1166
+ else:
1167
+ # prepare inputs without shared models
1168
+ if is_i2v:
1169
+ # I2V: need text encoder, VAE and CLIP
1170
+ vae = load_vae(args, cfg, device, vae_dtype)
1171
+ noise, context, context_null, y, inputs = prepare_i2v_inputs(args, cfg, accelerator, device, vae)
1172
+ # vae is on CPU after prepare_i2v_inputs
1173
+ else:
1174
+ # T2V: need text encoder
1175
+ vae = None
1176
+ if cfg.is_fun_control:
1177
+ # Fun-Control: need VAE for encoding control video
1178
+ vae = load_vae(args, cfg, device, vae_dtype)
1179
+ noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, vae)
1180
+
1181
+ # load DiT model
1182
+ model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v)
1183
+
1184
+ # merge LoRA weights
1185
+ if args.lora_weight is not None and len(args.lora_weight) > 0:
1186
+ merge_lora_weights(lora_wan, model, args, device)
1187
+
1188
+ # if we only want to save the model, we can skip the rest
1189
+ if args.save_merged_model:
1190
+ return None
1191
+
1192
+ # optimize model: fp8 conversion, block swap etc.
1193
+ optimize_model(model, args, device, dit_dtype, dit_weight_dtype)
1194
+
1195
+ # setup scheduler
1196
+ scheduler, timesteps = setup_scheduler(args, cfg, device)
1197
+
1198
+ # set random generator
1199
+ seed_g = torch.Generator(device=device)
1200
+ seed_g.manual_seed(seed)
1201
+
1202
+ # run sampling
1203
+ latent = run_sampling(model, noise, scheduler, timesteps, args, inputs, device, seed_g, accelerator, is_i2v)
1204
+
1205
+ # Only clean up shared models if they were created within this function
1206
+ if shared_models is None:
1207
+ # free memory
1208
+ del model
1209
+ del scheduler
1210
+ synchronize_device(device)
1211
+
1212
+ # wait for 5 seconds until block swap is done
1213
+ logger.info("Waiting for 5 seconds to finish block swap")
1214
+ time.sleep(5)
1215
+
1216
+ gc.collect()
1217
+ clean_memory_on_device(device)
1218
+
1219
+ # save VAE model for decoding
1220
+ if vae is None:
1221
+ args._vae = None
1222
+ else:
1223
+ args._vae = vae
1224
+
1225
+ return latent
1226
+
1227
+
1228
+ def decode_latent(latent: torch.Tensor, args: argparse.Namespace, cfg) -> torch.Tensor:
1229
+ """decode latent
1230
+
1231
+ Args:
1232
+ latent: latent tensor
1233
+ args: command line arguments
1234
+ cfg: model configuration
1235
+
1236
+ Returns:
1237
+ torch.Tensor: decoded video or image
1238
+ """
1239
+ device = torch.device(args.device)
1240
+
1241
+ # load VAE model or use the one from the generation
1242
+ vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else torch.bfloat16
1243
+ if hasattr(args, "_vae") and args._vae is not None:
1244
+ vae = args._vae
1245
+ else:
1246
+ vae = load_vae(args, cfg, device, vae_dtype)
1247
+
1248
+ vae.to_device(device)
1249
+
1250
+ logger.info(f"Decoding video from latents: {latent.shape}")
1251
+ x0 = latent.to(device)
1252
+
1253
+ with torch.autocast(device_type=device.type, dtype=vae_dtype), torch.no_grad():
1254
+ videos = vae.decode(x0)
1255
+
1256
+ # some tail frames may be corrupted when end frame is used, we add an option to remove them
1257
+ if args.trim_tail_frames:
1258
+ videos[0] = videos[0][:, : -args.trim_tail_frames]
1259
+
1260
+ logger.info(f"Decoding complete")
1261
+ video = videos[0]
1262
+ del videos
1263
+ video = video.to(torch.float32).cpu()
1264
+
1265
+ return video
1266
+
1267
+
1268
+ def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str:
1269
+ """Save latent to file
1270
+
1271
+ Args:
1272
+ latent: latent tensor
1273
+ args: command line arguments
1274
+ height: height of frame
1275
+ width: width of frame
1276
+
1277
+ Returns:
1278
+ str: Path to saved latent file
1279
+ """
1280
+ save_path = args.save_path
1281
+ os.makedirs(save_path, exist_ok=True)
1282
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
1283
+
1284
+ seed = args.seed
1285
+ video_length = args.video_length
1286
+ latent_path = f"{save_path}/{time_flag}_{seed}_latent.safetensors"
1287
+
1288
+ if args.no_metadata:
1289
+ metadata = None
1290
+ else:
1291
+ metadata = {
1292
+ "seeds": f"{seed}",
1293
+ "prompt": f"{args.prompt}",
1294
+ "height": f"{height}",
1295
+ "width": f"{width}",
1296
+ "video_length": f"{video_length}",
1297
+ "infer_steps": f"{args.infer_steps}",
1298
+ "guidance_scale": f"{args.guidance_scale}",
1299
+ }
1300
+ if args.negative_prompt is not None:
1301
+ metadata["negative_prompt"] = f"{args.negative_prompt}"
1302
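+ # These fields are read back by main() when --latent_path points to a .safetensors file,
+ # restoring seed, video_size, and video_length before decoding.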
+
1303
+ sd = {"latent": latent}
1304
+ save_file(sd, latent_path, metadata=metadata)
1305
+ logger.info(f"Latent saved to: {latent_path}")
1306
+
1307
+ return latent_path
1308
+
1309
+
1310
+ def save_video(video: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
1311
+ """Save video to file
1312
+
1313
+ Args:
1314
+ video: Video tensor
1315
+ args: command line arguments
1316
+ original_base_name: Original base name (if latents are loaded from files)
1317
+
1318
+ Returns:
1319
+ str: Path to saved video file
1320
+ """
1321
+ save_path = args.save_path
1322
+ os.makedirs(save_path, exist_ok=True)
1323
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
1324
+
1325
+ seed = args.seed
1326
+ original_name = "" if original_base_name is None else f"_{original_base_name}"
1327
+ video_path = f"{save_path}/{time_flag}_{seed}{original_name}.mp4"
1328
+
1329
+ video = video.unsqueeze(0)
1330
+ save_videos_grid(video, video_path, fps=args.fps, rescale=True)
1331
+ logger.info(f"Video saved to: {video_path}")
1332
+
1333
+ return video_path
1334
+
1335
+
1336
+ def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
1337
+ """Save images to directory
1338
+
1339
+ Args:
1340
+ sample: Video tensor
1341
+ args: command line arguments
1342
+ original_base_name: Original base name (if latents are loaded from files)
1343
+
1344
+ Returns:
1345
+ str: Path to saved images directory
1346
+ """
1347
+ save_path = args.save_path
1348
+ os.makedirs(save_path, exist_ok=True)
1349
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
1350
+
1351
+ seed = args.seed
1352
+ original_name = "" if original_base_name is None else f"_{original_base_name}"
1353
+ image_name = f"{time_flag}_{seed}{original_name}"
1354
+ sample = sample.unsqueeze(0)
1355
+ save_images_grid(sample, save_path, image_name, rescale=True)
1356
+ logger.info(f"Sample images saved to: {save_path}/{image_name}")
1357
+
1358
+ return f"{save_path}/{image_name}"
1359
+
1360
+
1361
+ def save_output(
1362
+ latent: torch.Tensor, args: argparse.Namespace, cfg, height: int, width: int, original_base_names: Optional[List[str]] = None
1363
+ ) -> None:
1364
+ """save output
1365
+
1366
+ Args:
1367
+ latent: latent tensor
1368
+ args: command line arguments
1369
+ cfg: model configuration
1370
+ height: height of frame
1371
+ width: width of frame
1372
+ original_base_names: original base names (if latents are loaded from files)
1373
+ """
1374
+ if args.output_type == "latent" or args.output_type == "both":
1375
+ # save latent
1376
+ save_latent(latent, args, height, width)
1377
+
1378
+ if args.output_type == "video" or args.output_type == "both":
1379
+ # save video
1380
+ sample = decode_latent(latent.unsqueeze(0), args, cfg)
1381
+ original_name = "" if original_base_names is None else f"_{original_base_names[0]}"
1382
+ save_video(sample, args, original_name)
1383
+
1384
+ elif args.output_type == "images":
1385
+ # save images
1386
+ sample = decode_latent(latent.unsqueeze(0), args, cfg)
1387
+ original_name = "" if original_base_names is None else f"_{original_base_names[0]}"
1388
+ save_images(sample, args, original_name)
1389
+
1390
+
1391
+ def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]:
1392
+ """Process multiple prompts for batch mode
1393
+
1394
+ Args:
1395
+ prompt_lines: List of prompt lines
1396
+ base_args: Base command line arguments
1397
+
1398
+ Returns:
1399
+ List[Dict]: List of prompt data dictionaries
1400
+ """
1401
+ prompts_data = []
1402
+
1403
+ for line in prompt_lines:
1404
+ line = line.strip()
1405
+ if not line or line.startswith("#"): # Skip empty lines and comments
1406
+ continue
1407
+
1408
+ # Parse prompt line and create override dictionary
1409
+ prompt_data = parse_prompt_line(line)
1410
+ logger.info(f"Parsed prompt data: {prompt_data}")
1411
+ prompts_data.append(prompt_data)
1412
+
1413
+ return prompts_data
1414
+
1415
+
1416
+ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) -> None:
1417
+ """Process multiple prompts with model reuse
1418
+
1419
+ Args:
1420
+ prompts_data: List of prompt data dictionaries
1421
+ args: Base command line arguments
1422
+ """
1423
+ if not prompts_data:
1424
+ logger.warning("No valid prompts found")
1425
+ return
1426
+
1427
+ # 1. Load configuration
1428
+ gen_settings = get_generation_settings(args)
1429
+ device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = (
1430
+ gen_settings.device,
1431
+ gen_settings.cfg,
1432
+ gen_settings.dit_dtype,
1433
+ gen_settings.dit_weight_dtype,
1434
+ gen_settings.vae_dtype,
1435
+ )
1436
+ is_i2v = "i2v" in args.task
1437
+
1438
+ # 2. Encode all prompts
1439
+ logger.info("Loading text encoder to encode all prompts")
1440
+ text_encoder = load_text_encoder(args, cfg, device)
1441
+ text_encoder.model.to(device)
1442
+
1443
+ encoded_contexts = {}
1444
+
1445
+ with torch.no_grad():
1446
+ for prompt_data in prompts_data:
1447
+ prompt = prompt_data["prompt"]
1448
+ prompt_args = apply_overrides(args, prompt_data)
1449
+ n_prompt = prompt_data.get(
1450
+ "negative_prompt", prompt_args.negative_prompt if prompt_args.negative_prompt else cfg.sample_neg_prompt
1451
+ )
1452
+
1453
+ if args.fp8_t5:
1454
+ with torch.amp.autocast(device_type=device.type, dtype=cfg.t5_dtype):
1455
+ context = text_encoder([prompt], device)
1456
+ context_null = text_encoder([n_prompt], device)
1457
+ else:
1458
+ context = text_encoder([prompt], device)
1459
+ context_null = text_encoder([n_prompt], device)
1460
+
1461
+ encoded_contexts[prompt] = {"context": context, "context_null": context_null}
1462
+
1463
+ # Free text encoder and clean memory
1464
+ del text_encoder
1465
+ clean_memory_on_device(device)
1466
+
1467
+ # 3. Process I2V additional encodings if needed
1468
+ vae = None
1469
+ if is_i2v:
1470
+ logger.info("Loading VAE and CLIP for I2V preprocessing")
1471
+ vae = load_vae(args, cfg, device, vae_dtype)
1472
+ vae.to_device(device)
1473
+
1474
+ clip = load_clip_model(args, cfg, device)
1475
+ clip.model.to(device)
1476
+
1477
+ # Process each image and encode with CLIP
1478
+ for prompt_data in prompts_data:
1479
+ if "image_path" not in prompt_data:
1480
+ continue
1481
+
1482
+ prompt_args = apply_overrides(args, prompt_data)
1483
+ if not os.path.exists(prompt_args.image_path):
1484
+ logger.warning(f"Image path not found: {prompt_args.image_path}")
1485
+ continue
1486
+
1487
+ # Load and encode image with CLIP
1488
+ img = Image.open(prompt_args.image_path).convert("RGB")
1489
+ img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)
1490
+
1491
+ with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
1492
+ clip_context = clip.visual([img_tensor[:, None, :, :]])
1493
+
1494
+ encoded_contexts[prompt_data["prompt"]]["clip_context"] = clip_context
1495
+
1496
+ # Free CLIP and clean memory
1497
+ del clip
1498
+ clean_memory_on_device(device)
1499
+
1500
+ # Keep VAE in CPU memory for later use
1501
+ vae.to_device("cpu")
1502
+ elif cfg.is_fun_control:
1503
+ # For Fun-Control, we need VAE but keep it on CPU
1504
+ vae = load_vae(args, cfg, device, vae_dtype)
1505
+ vae.to_device("cpu")
1506
+
1507
+ # 4. Load DiT model
1508
+ logger.info("Loading DiT model")
1509
+ model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v)
1510
+
1511
+ # 5. Merge LoRA weights if needed
1512
+ if args.lora_weight is not None and len(args.lora_weight) > 0:
1513
+ merge_lora_weights(lora_wan, model, args, device)
1514
+ if args.save_merged_model:
1515
+ logger.info("Model merged and saved. Exiting.")
1516
+ return
1517
+
1518
+ # 6. Optimize model
1519
+ optimize_model(model, args, device, dit_dtype, dit_weight_dtype)
1520
+
1521
+ # Create shared models dict for generate function
1522
+ shared_models = {"vae": vae, "model": model, "encoded_contexts": encoded_contexts}
1523
+
1524
+ # 7. Generate for each prompt
1525
+ all_latents = []
1526
+ all_prompt_args = []
1527
+
1528
+ for i, prompt_data in enumerate(prompts_data):
1529
+ logger.info(f"Processing prompt {i+1}/{len(prompts_data)}: {prompt_data['prompt'][:50]}...")
1530
+
1531
+ # Apply overrides for this prompt
1532
+ prompt_args = apply_overrides(args, prompt_data)
1533
+
1534
+ # Generate latent
1535
+ latent = generate(prompt_args, gen_settings, shared_models)
1536
+
1537
+ # Save latent if needed
1538
+ height, width, _ = check_inputs(prompt_args)
1539
+ if prompt_args.output_type == "latent" or prompt_args.output_type == "both":
1540
+ save_latent(latent, prompt_args, height, width)
1541
+
1542
+ all_latents.append(latent)
1543
+ all_prompt_args.append(prompt_args)
1544
+
1545
+ # 8. Free DiT model
1546
+ del model
1547
+ clean_memory_on_device(device)
1548
+ synchronize_device(device)
1549
+
1550
+ # wait for 5 seconds until block swap is done
1551
+ logger.info("Waiting for 5 seconds to finish block swap")
1552
+ time.sleep(5)
1553
+
1554
+ gc.collect()
1555
+ clean_memory_on_device(device)
1556
+
1557
+ # 9. Decode latents if needed
1558
+ if args.output_type != "latent":
1559
+ logger.info("Decoding latents to videos/images")
1560
+
1561
+ if vae is None:
1562
+ vae = load_vae(args, cfg, device, vae_dtype)
1563
+
1564
+ vae.to_device(device)
1565
+
1566
+ for i, (latent, prompt_args) in enumerate(zip(all_latents, all_prompt_args)):
1567
+ logger.info(f"Decoding output {i+1}/{len(all_latents)}")
1568
+
1569
+ # Decode latent
1570
+ video = decode_latent(latent.unsqueeze(0), prompt_args, cfg)
1571
+
1572
+ # Save as video or images
1573
+ if prompt_args.output_type == "video" or prompt_args.output_type == "both":
1574
+ save_video(video, prompt_args)
1575
+ elif prompt_args.output_type == "images":
1576
+ save_images(video, prompt_args)
1577
+
1578
+ # Free VAE
1579
+ del vae
1580
+
1581
+ clean_memory_on_device(device)
1582
+ gc.collect()
1583
+
1584
+
1585
+ def process_interactive(args: argparse.Namespace) -> None:
1586
+ """Process prompts in interactive mode
1587
+
1588
+ Args:
1589
+ args: Base command line arguments
1590
+ """
1591
+ gen_settings = get_generation_settings(args)
1592
+ device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = (
1593
+ gen_settings.device,
1594
+ gen_settings.cfg,
1595
+ gen_settings.dit_dtype,
1596
+ gen_settings.dit_weight_dtype,
1597
+ gen_settings.vae_dtype,
1598
+ )
1599
+ is_i2v = "i2v" in args.task
1600
+
1601
+ # Initialize models to None
1602
+ text_encoder = None
1603
+ vae = None
1604
+ model = None
1605
+ clip = None
1606
+
1607
+ print("Interactive mode. Enter prompts (Ctrl+D to exit):")
1608
+
1609
+ try:
1610
+ while True:
1611
+ try:
1612
+ line = input("> ")
1613
+ if not line.strip():
1614
+ continue
1615
+
1616
+ # Parse prompt
1617
+ prompt_data = parse_prompt_line(line)
1618
+ prompt_args = apply_overrides(args, prompt_data)
1619
+
1620
+ # Ensure we have all the models we need
1621
+
1622
+ # 1. Load text encoder if not already loaded
1623
+ if text_encoder is None:
1624
+ logger.info("Loading text encoder")
1625
+ text_encoder = load_text_encoder(args, cfg, device)
1626
+
1627
+ text_encoder.model.to(device)
1628
+
1629
+ # Encode prompt
1630
+ n_prompt = prompt_data.get(
1631
+ "negative_prompt", prompt_args.negative_prompt if prompt_args.negative_prompt else cfg.sample_neg_prompt
1632
+ )
1633
+
1634
+ with torch.no_grad():
1635
+ if args.fp8_t5:
1636
+ with torch.amp.autocast(device_type=device.type, dtype=cfg.t5_dtype):
1637
+ context = text_encoder([prompt_data["prompt"]], device)
1638
+ context_null = text_encoder([n_prompt], device)
1639
+ else:
1640
+ context = text_encoder([prompt_data["prompt"]], device)
1641
+ context_null = text_encoder([n_prompt], device)
1642
+
1643
+ encoded_context = {"context": context, "context_null": context_null}
1644
+
1645
+ # Move text encoder to CPU after use
1646
+ text_encoder.model.to("cpu")
1647
+
1648
+ # 2. For I2V, we need CLIP and VAE
1649
+ if is_i2v:
1650
+ if clip is None:
1651
+ logger.info("Loading CLIP model")
1652
+ clip = load_clip_model(args, cfg, device)
1653
+
1654
+ clip.model.to(device)
1655
+
1656
+ # Encode image with CLIP if there's an image path
1657
+ if prompt_args.image_path and os.path.exists(prompt_args.image_path):
1658
+ img = Image.open(prompt_args.image_path).convert("RGB")
1659
+ img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)
1660
+
1661
+ with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
1662
+ clip_context = clip.visual([img_tensor[:, None, :, :]])
1663
+
1664
+ encoded_context["clip_context"] = clip_context
1665
+
1666
+ # Move CLIP to CPU after use
1667
+ clip.model.to("cpu")
1668
+
1669
+ # Load VAE if needed
1670
+ if vae is None:
1671
+ logger.info("Loading VAE model")
1672
+ vae = load_vae(args, cfg, device, vae_dtype)
1673
+ elif cfg.is_fun_control and vae is None:
1674
+ # For Fun-Control, we need VAE
1675
+ logger.info("Loading VAE model for Fun-Control")
1676
+ vae = load_vae(args, cfg, device, vae_dtype)
1677
+
1678
+ # 3. Load DiT model if not already loaded
1679
+ if model is None:
1680
+ logger.info("Loading DiT model")
1681
+ model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v)
1682
+
1683
+ # Merge LoRA weights if needed
1684
+ if args.lora_weight is not None and len(args.lora_weight) > 0:
1685
+ merge_lora_weights(lora_wan, model, args, device)
1686
+
1687
+ # Optimize model
1688
+ optimize_model(model, args, device, dit_dtype, dit_weight_dtype)
1689
+ else:
1690
+ # Move model to GPU if it was offloaded
1691
+ model.to(device)
1692
+
1693
+ # Create shared models dict
1694
+ shared_models = {"vae": vae, "model": model, "encoded_contexts": {prompt_data["prompt"]: encoded_context}}
1695
+
1696
+ # Generate latent
1697
+ latent = generate(prompt_args, gen_settings, shared_models)
1698
+
1699
+ # Move model to CPU after generation
1700
+ model.to("cpu")
1701
+
1702
+ # Save latent if needed
1703
+ height, width, _ = check_inputs(prompt_args)
1704
+ if prompt_args.output_type == "latent" or prompt_args.output_type == "both":
1705
+ save_latent(latent, prompt_args, height, width)
1706
+
1707
+ # Decode and save output
1708
+ if prompt_args.output_type != "latent":
1709
+ if vae is None:
1710
+ vae = load_vae(args, cfg, device, vae_dtype)
1711
+
1712
+ vae.to_device(device)
1713
+ video = decode_latent(latent.unsqueeze(0), prompt_args, cfg)
1714
+
1715
+ if prompt_args.output_type == "video" or prompt_args.output_type == "both":
1716
+ save_video(video, prompt_args)
1717
+ elif prompt_args.output_type == "images":
1718
+ save_images(video, prompt_args)
1719
+
1720
+ # Move VAE to CPU after use
1721
+ vae.to_device("cpu")
1722
+
1723
+ clean_memory_on_device(device)
1724
+
1725
+ except KeyboardInterrupt:
1726
+ print("\nInterrupted. Continue (Ctrl+D or Ctrl+Z (Windows) to exit)")
1727
+ continue
1728
+
1729
+ except EOFError:
1730
+ print("\nExiting interactive mode")
1731
+
1732
+ # Clean up all models
1733
+ if text_encoder is not None:
1734
+ del text_encoder
1735
+ if clip is not None:
1736
+ del clip
1737
+ if vae is not None:
1738
+ del vae
1739
+ if model is not None:
1740
+ del model
1741
+
1742
+ clean_memory_on_device(device)
1743
+ gc.collect()
1744
+
1745
+
1746
+ def get_generation_settings(args: argparse.Namespace) -> GenerationSettings:
1747
+ device = torch.device(args.device)
1748
+
1749
+ cfg = WAN_CONFIGS[args.task]
1750
+
1751
+ # select dtype
1752
+ dit_dtype = detect_wan_sd_dtype(args.dit) if args.dit is not None else torch.bfloat16
1753
+ if dit_dtype.itemsize == 1:
1754
+ # if weight is in fp8, use bfloat16 for DiT (input/output)
1755
+ dit_dtype = torch.bfloat16
1756
+ if args.fp8_scaled:
1757
+ raise ValueError(
1758
+ "DiT weights is already in fp8 format, cannot scale to fp8. Please use fp16/bf16 weights / DiTの重みはすでにfp8形式です。fp8にスケーリングできません。fp16/bf16の重みを使用してください"
1759
+ )
1760
+
1761
+ dit_weight_dtype = dit_dtype # default
1762
+ if args.fp8_scaled:
1763
+ dit_weight_dtype = None # various precision weights, so don't cast to specific dtype
1764
+ elif args.fp8:
1765
+ dit_weight_dtype = torch.float8_e4m3fn
1766
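+ # Example of the resulting combination: a bf16 checkpoint with --fp8 runs with dit_dtype=bf16
+ # for activations and dit_weight_dtype=float8_e4m3fn for storage, while a checkpoint already
+ # saved in fp8 (itemsize == 1) keeps bf16 activations and rejects --fp8_scaled above.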
+
1767
+ vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else dit_dtype
1768
+ logger.info(
1769
+ f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}, VAE precision: {vae_dtype}"
1770
+ )
1771
+
1772
+ gen_settings = GenerationSettings(
1773
+ device=device,
1774
+ cfg=cfg,
1775
+ dit_dtype=dit_dtype,
1776
+ dit_weight_dtype=dit_weight_dtype,
1777
+ vae_dtype=vae_dtype,
1778
+ )
1779
+ return gen_settings
1780
+
1781
+
1782
+ def main():
1783
+ # Parse arguments
1784
+ args = parse_args()
1785
+
1786
+ # Check if latents are provided
1787
+ latents_mode = args.latent_path is not None and len(args.latent_path) > 0
1788
+
1789
+ # Set device
1790
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
1791
+ device = torch.device(device)
1792
+ logger.info(f"Using device: {device}")
1793
+ args.device = device
1794
+
1795
+ if latents_mode:
1796
+ # Original latent decode mode
1797
+ cfg = WAN_CONFIGS[args.task] # any task is fine
1798
+ original_base_names = []
1799
+ latents_list = []
1800
+ seeds = []
1801
+
1802
+ assert len(args.latent_path) == 1, "Only one latent path is supported for now"
1803
+
1804
+ for latent_path in args.latent_path:
1805
+ original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
1806
+ seed = 0
1807
+
1808
+ if os.path.splitext(latent_path)[1] != ".safetensors":
1809
+ latents = torch.load(latent_path, map_location="cpu")
1810
+ else:
1811
+ latents = load_file(latent_path)["latent"]
1812
+ with safe_open(latent_path, framework="pt") as f:
1813
+ metadata = f.metadata()
1814
+ if metadata is None:
1815
+ metadata = {}
1816
+ logger.info(f"Loaded metadata: {metadata}")
1817
+
1818
+ if "seeds" in metadata:
1819
+ seed = int(metadata["seeds"])
1820
+ if "height" in metadata and "width" in metadata:
1821
+ height = int(metadata["height"])
1822
+ width = int(metadata["width"])
1823
+ args.video_size = [height, width]
1824
+ if "video_length" in metadata:
1825
+ args.video_length = int(metadata["video_length"])
1826
+
1827
+ seeds.append(seed)
1828
+ latents_list.append(latents)
1829
+
1830
+ logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")
1831
+
1832
+ latent = torch.stack(latents_list, dim=0) # [N, ...], must be same shape
1833
+
1834
+ height = latents.shape[-2]
1835
+ width = latents.shape[-1]
1836
+ height *= cfg.patch_size[1] * cfg.vae_stride[1]
1837
+ width *= cfg.patch_size[2] * cfg.vae_stride[2]
1838
+ video_length = latents.shape[1]
1839
+ video_length = (video_length - 1) * cfg.vae_stride[0] + 1
1840
+ args.seed = seeds[0]
1841
+
1842
+ # Decode and save
1843
+ save_output(latent[0], args, cfg, height, width, original_base_names)
1844
+
1845
+ elif args.from_file:
1846
+ # Batch mode from file
1847
+ args = setup_args(args)
1848
+
1849
+ # Read prompts from file
1850
+ with open(args.from_file, "r", encoding="utf-8") as f:
1851
+ prompt_lines = f.readlines()
1852
+
1853
+ # Process prompts
1854
+ prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
1855
+ process_batch_prompts(prompts_data, args)
1856
+
1857
+ elif args.interactive:
1858
+ # Interactive mode
1859
+ args = setup_args(args)
1860
+ process_interactive(args)
1861
+
1862
+ else:
1863
+ # Single prompt mode (original behavior)
1864
+ args = setup_args(args)
1865
+ height, width, video_length = check_inputs(args)
1866
+
1867
+ logger.info(
1868
+ f"Video size: {height}x{width}@{video_length} (HxW@F), fps: {args.fps}, "
1869
+ f"infer_steps: {args.infer_steps}, flow_shift: {args.flow_shift}"
1870
+ )
1871
+
1872
+ # Generate latent
1873
+ gen_settings = get_generation_settings(args)
1874
+ latent = generate(args, gen_settings)
1875
+
1876
+ # Make sure the model is freed from GPU memory
1877
+ gc.collect()
1878
+ clean_memory_on_device(args.device)
1879
+
1880
+ # Save latent and video
1881
+ if args.save_merged_model:
1882
+ return
1883
+
1884
+ # Add batch dimension
1885
+ latent = latent.unsqueeze(0)
1886
+ save_output(latent[0], args, WAN_CONFIGS[args.task], height, width)
1887
+
1888
+ logger.info("Done!")
1889
+
1890
+
1891
+ if __name__ == "__main__":
1892
+ main()
blissful_tuner/GIMMVFI.py ADDED
@@ -0,0 +1,208 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Frame Rate Interpolation using GIMM-VFI
5
+ -----------------------------------
6
+ This specific code file as well as all files in ./blissful_tuner/gimmvfi and subfolders (all GIMM-VFI related code) licensed:
7
+
8
+ S-Lab License 1.0
9
+ Copyright 2024 S-Lab
10
+
11
+ Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
12
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
13
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
14
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
15
+
16
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17
+ IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
18
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
19
+
20
+ In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
21
+ ---------------------------------------
22
+ Created on Mon Apr 14 12:23:15 2025
23
+ @author: blyss
24
+ """
25
+
26
+ import os
27
+ import warnings
28
+ from typing import List
29
+ import torch
30
+ import yaml
31
+ from tqdm import tqdm
32
+ from omegaconf import OmegaConf
33
+ from rich.traceback import install as install_rich_tracebacks
34
+
35
+ # Importing necessary modules from our project
36
+ from gimmvfi.generalizable_INR.gimmvfi_r import GIMMVFI_R
37
+ from gimmvfi.generalizable_INR.gimmvfi_f import GIMMVFI_F
38
+ from gimmvfi.generalizable_INR.configs import GIMMVFIConfig
39
+ from gimmvfi.generalizable_INR.raft import RAFT
40
+ from gimmvfi.generalizable_INR.flowformer.core.FlowFormer.LatentCostFormer.transformer import FlowFormer
41
+ from gimmvfi.generalizable_INR.flowformer.configs.submission import get_cfg
42
+ from gimmvfi.utils.utils import InputPadder, RaftArgs, easydict_to_dict
43
+ from utils import load_torch_file, setup_compute_context
44
+ from video_processing_common import BlissfulVideoProcessor, setup_parser_video_common, set_seed
45
+ warnings.filterwarnings("ignore")
46
+ install_rich_tracebacks()
47
+
48
+
49
+ def load_model(model_path: str, device: torch.device, dtype: torch.dtype, mode: str = "gimmvfi_r") -> torch.nn.Module:
50
+ """
51
+ Loads the GIMM-VFI model along with its required flow estimator.
52
+
53
+ Depending on the mode ("gimmvfi_r" or "gimmvfi_f") a different configuration,
54
+ checkpoint, and flow estimation network are loaded.
55
+ """
56
+
57
+ # Select proper configuration, checkpoint, and flow model based on mode.
58
+ if "gimmvfi_r" in mode:
59
+ config_path = os.path.join(model_path, "gimmvfi_r_arb.yaml")
60
+ flow_model_filename = "raft-things_fp32.safetensors"
61
+ checkpoint = os.path.join(model_path, "gimmvfi_r_arb_lpips_fp32.safetensors")
62
+ elif "gimmvfi_f" in mode:
63
+ config_path = os.path.join(model_path, "gimmvfi_f_arb.yaml")
64
+ checkpoint = os.path.join(model_path, "gimmvfi_f_arb_lpips_fp32.safetensors")
65
+ flow_model_filename = "flowformer_sintel_fp32.safetensors"
66
+ else:
67
+ raise ValueError(f"Unsupported mode: {mode}")
68
+
69
+ flow_model_path = os.path.join(model_path, flow_model_filename)
70
+
71
+ # Load and merge YAML configuration
72
+ with open(config_path) as f:
73
+ config = yaml.load(f, Loader=yaml.FullLoader)
74
+ config = easydict_to_dict(config)
75
+ config = OmegaConf.create(config)
76
+ arch_defaults = GIMMVFIConfig.create(config.arch)
77
+ config = OmegaConf.merge(arch_defaults, config.arch)
78
+
79
+ # Initialize the model and its associated flow estimator
80
+ if "gimmvfi_r" in mode:
81
+ model = GIMMVFI_R(config)
82
+ # Setup RAFT as flow estimator
83
+ raft_args = RaftArgs(small=False, mixed_precision=False, alternate_corr=False)
84
+ raft_model = RAFT(raft_args)
85
+ raft_sd = load_torch_file(flow_model_path)
86
+ raft_model.load_state_dict(raft_sd, strict=True)
87
+ flow_estimator = raft_model.to(device, dtype)
88
+ else: # mode == "gimmvfi_f"
89
+ model = GIMMVFI_F(config)
90
+ cfg = get_cfg()
91
+ flowformer = FlowFormer(cfg.latentcostformer)
92
+ flowformer_sd = load_torch_file(flow_model_path)
93
+ flowformer.load_state_dict(flowformer_sd, strict=True)
94
+ flow_estimator = flowformer.to(device, dtype)
95
+
96
+ # Load main model checkpoint
97
+ sd = load_torch_file(checkpoint)
98
+ model.load_state_dict(sd, strict=False)
99
+
100
+ # Attach the flow estimator to the model, set evaluation mode, and move to device
101
+ model.flow_estimator = flow_estimator
102
+ model = model.eval().to(device, dtype)
103
+
104
+ return model
105
+
106
+
107
+ def interpolate(model: torch.nn.Module, frames: List[torch.Tensor], ds_factor: float, N: int, VideoProcessor: BlissfulVideoProcessor):
108
+ """
109
+ Interpolates frames using the provided model.
110
+
111
+ Args:
112
+ model: The loaded interpolation model.
113
+ frames: List of input frame tensors.
114
+ ds_factor: Downsampling factor used by the model.
115
+ N: Number of interpolation steps between two frames.
116
+ """
117
+ device = VideoProcessor.device
118
+ dtype = VideoProcessor.dtype
119
+ start = 0
120
+ end = len(frames) - 1
121
+
122
+ # Process each adjacent pair of frames.
123
+ for j in tqdm(range(start, end), desc="Interpolating frames"):
124
+ I0 = frames[j]
125
+ I2 = frames[j + 1]
126
+
127
+ # For the very first frame, add it directly.
128
+ if j == start:
129
+ VideoProcessor.write_np_or_tensor_to_png(I0)
130
+
131
+ # Pad both images so that their dimensions are divisible by 32.
132
+ padder = InputPadder(I0.shape, 32)
133
+ I0_padded, I2_padded = padder.pad(I0, I2)
134
+ # Concatenate along a new dimension to create a tensor of shape [batch, 2, C, H, W]
135
+ xs = torch.cat((I0_padded.unsqueeze(2), I2_padded.unsqueeze(2)), dim=2).to(device, dtype, non_blocking=True)
136
+
137
+ model.zero_grad()
138
+
139
+ batch_size = xs.shape[0]
140
+ s_shape = xs.shape[-2:]
141
+
142
+ with torch.no_grad():
143
+ # Prepare coordinate inputs and timesteps for interpolation.
144
+ coord_inputs = [
145
+ (
146
+ model.sample_coord_input(
147
+ batch_size,
148
+ s_shape,
149
+ [1 / N * i],
150
+ device=xs.device,
151
+ upsample_ratio=ds_factor,
152
+ ),
153
+ None,
154
+ )
155
+ for i in range(1, N)
156
+ ]
157
+ timesteps = [
158
+ i / N * torch.ones(batch_size, device=xs.device, dtype=dtype)
159
+ for i in range(1, N)
160
+ ]
161
+ if dtype != torch.float32:
162
+ with torch.autocast(device_type=str(device), dtype=dtype):
163
+ all_outputs = model(xs, coord_inputs, t=timesteps, ds_factor=ds_factor)
164
+ else:
165
+ all_outputs = model(xs, coord_inputs, t=timesteps, ds_factor=ds_factor)
166
+ # Unpad the outputs to get back to original image size.
167
+ out_frames = [padder.unpad(im) for im in all_outputs["imgt_pred"]]
168
+
169
+ # Convert each interpolated frame tensor to an image array.
170
+ I1_pred_images = [I1_pred[0] for I1_pred in out_frames]
171
+
172
+ # Append the interpolated frames and corresponding flow images.
173
+ for i in range(N - 1):
174
+ VideoProcessor.write_np_or_tensor_to_png(I1_pred_images[i])
175
+
176
+ # Append the next original frame.
177
+ VideoProcessor.write_np_or_tensor_to_png(I2)
178
+
179
+
180
+ def main():
181
+ parser = setup_parser_video_common(description="Frame rate interpolation using GIMM-VFI")
182
+ parser.add_argument("--ds_factor", type=float, default=1.0, help="Downsampling factor")
183
+ parser.add_argument("--mode", type=str, default="gimmvfi_f", help="Model mode: 'gimmvfi_r' or 'gimmvfi_f' for RAFT or FlowFormer version respectively")
184
+ parser.add_argument(
185
+ "--factor", type=int, default=2, help="Factor to increase the number of frames by. \
186
+ A factor of 2 will double the fps, taking e.g. a 16fps video to 32fps. Can go up to 8 but higher values have more artifacts"
187
+ )
188
+ args = parser.parse_args()
189
+ device, dtype = setup_compute_context(None, args.dtype)
190
+ VideoProcessor = BlissfulVideoProcessor(device, dtype)
191
+ VideoProcessor.prepare_files_and_path(args.input, args.output, "VFI", args.codec, args.container)
192
+ model = load_model(args.model, device, dtype, args.mode)
193
+ frames, fps, _, _ = VideoProcessor.load_frames(make_rgb=True)
194
+ frames = VideoProcessor.np_image_to_tensor(frames)
195
+ new_fps = fps * args.factor # Adjust the frame rate according to the interpolation
196
+
197
+ # Set seed for reproducibility.
198
+ set_seed(args.seed)
199
+
200
+ # Perform the frame interpolation.
201
+ interpolate(model, frames, args.ds_factor, args.factor, VideoProcessor)
202
+
203
+ # Save the interpolated video.
204
+ VideoProcessor.write_buffered_frames_to_output(new_fps, args.keep_pngs)
205
+
206
+
207
+ if __name__ == "__main__":
208
+ main()
blissful_tuner/__init__.py ADDED
File without changes
blissful_tuner/advanced_rope.py ADDED
@@ -0,0 +1,112 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Wed Apr 16 19:25:53 2025
5
+ Advanced rope functions for Blissful Tuner extension
6
+ License: Apache 2.0
7
+
8
+ @author: blyss
9
+ """
10
+ import torch
11
+ import torch.nn as nn
12
+ from einops import rearrange
13
+ from typing import List, Tuple
14
+ from blissful_tuner.hvw_posemb_layers import get_nd_rotary_pos_embed
15
+
16
+
17
+ # From ComfyUI
18
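+ # Note: freqs_cis is expected in the real-valued 2x2 rotation form produced by rope_riflex below
+ # ([cos, -sin; sin, cos] per frequency), so each (x0, x1) pair is rotated with broadcast
+ # multiply-adds instead of complex multiplication.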
+ def apply_rope_comfy(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
19
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
20
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
21
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
22
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
23
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
24
+
25
+
26
+ # From WanVideoWrapper
27
+ def rope_riflex(pos, dim, theta, L_test, k, temporal):
28
+ assert dim % 2 == 0
29
+ device = pos.device
30
+ scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64, device=device)
31
+ omega = 1.0 / (theta**scale)
32
+ # RIFLEX modification - adjust last frequency component if L_test and k are provided
33
+ if temporal and k > 0 and L_test:
34
+ omega[k - 1] = 0.9 * 2 * torch.pi / L_test
35
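+ # e.g. with k=6 and L_test=21 latent frames, the 6th temporal frequency gets a period of
+ # L_test / 0.9 ≈ 23.3 frames, pushing the point where the positional pattern would repeat
+ # past the end of the clip.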
+ out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
36
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
37
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
38
+ return out.to(dtype=torch.float32, device=pos.device)
39
+
40
+
41
+ class EmbedND_RifleX(nn.Module):
42
+ def __init__(self: nn.Module, dim: int, theta: float, axes_dim: List[int], num_frames: int, k: int):
43
+ super().__init__()
44
+ self.dim = dim
45
+ self.theta = theta
46
+ self.axes_dim = axes_dim
47
+ self.num_frames = num_frames
48
+ self.k = k
49
+
50
+ def forward(self, ids):
51
+ n_axes = ids.shape[-1]
52
+ emb = torch.cat(
53
+ [rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k, temporal=True if i == 0 else False) for i in range(n_axes)],
54
+ dim=-3,
55
+ )
56
+ return emb.unsqueeze(1)
57
+
58
+
59
+ # Modified from HunyuanVideo Wrapper
60
+ def get_rotary_pos_embed_riflex(vae_ver, transformer, latent_video_length, height, width, k=0):
61
+ if "884" in vae_ver:
62
+ latents_size = [(latent_video_length - 1) // 4 + 1, height // 8, width // 8]
63
+ elif "888" in vae_ver:
64
+ latents_size = [(latent_video_length - 1) // 8 + 1, height // 8, width // 8]
65
+ else:
66
+ latents_size = [latent_video_length, height // 8, width // 8]
67
+
68
+ target_ndim = 3
69
+ ndim = 5 - 2
70
+ rope_theta = 256 # 225
71
+ patch_size = transformer.patch_size
72
+ rope_dim_list = transformer.rope_dim_list
73
+ hidden_size = transformer.hidden_size
74
+ heads_num = transformer.heads_num
75
+ head_dim = hidden_size // heads_num
76
+
77
+ if isinstance(patch_size, int):
78
+ assert all(s % patch_size == 0 for s in latents_size), (
79
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
80
+ f"but got {latents_size}."
81
+ )
82
+ rope_sizes = [s // patch_size for s in latents_size]
83
+ elif isinstance(patch_size, list):
84
+ assert all(
85
+ s % patch_size[idx] == 0
86
+ for idx, s in enumerate(latents_size)
87
+ ), (
88
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
89
+ f"but got {latents_size}."
90
+ )
91
+ rope_sizes = [
92
+ s // patch_size[idx] for idx, s in enumerate(latents_size)
93
+ ]
94
+
95
+ if len(rope_sizes) != target_ndim:
96
+ rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
97
+
98
+ if rope_dim_list is None:
99
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
100
+ assert (
101
+ sum(rope_dim_list) == head_dim
102
+ ), "sum(rope_dim_list) should equal to head_dim of attention layer"
103
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
104
+ rope_dim_list,
105
+ rope_sizes,
106
+ theta=rope_theta,
107
+ use_real=True,
108
+ theta_rescale_factor=1,
109
+ num_frames=latent_video_length,
110
+ k=k,
111
+ )
112
+ return freqs_cos, freqs_sin
blissful_tuner/blissful_args.py ADDED
@@ -0,0 +1,131 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Sat Apr 26 15:11:58 2025
5
+
6
+ @author: blyss
7
+ """
8
+ import sys
9
+ import os
10
+ import argparse
11
+ import torch
12
+ from rich.traceback import install as install_rich_tracebacks
13
+ from blissful_tuner.utils import BlissfulLogger, string_to_seed, parse_scheduled_cfg, error_out
14
+ logger = BlissfulLogger(__name__, "#8e00ed")
15
+
16
+ BLISSFUL_VERSION = "0.4.0"
17
+
18
+ CFG_SCHEDULE_HELP = """
19
+ Comma-separated list of steps/ranges where CFG should be applied.
20
+
21
+ You can specify:
22
+ - Single steps (e.g., '5')
23
+ - Ranges (e.g., '1-10')
24
+ - Modulus patterns (e.g., 'e~2' for every 2 steps)
25
+ - Guidance scale overrides (e.g., '1-10:5.0')
26
+
27
+ Example schedule:
28
+ 'e~2:6.4, 1-10, 46-50'
29
+
30
+ This would apply:
31
+ - Default CFG scale for steps 1-10 and 46-50
32
+ - 6.4 CFG scale every 2 steps outside that range
33
+ - No CFG otherwise
34
+
35
+ You can exclude steps using '!', e.g., '!32' skips step 32.
36
+ Note: The list is processed left to right, so modulus ranges should come first and exclusions at the end!
37
+ """
38
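+ # Hedged reading of the example schedule above, assuming 'e~2' selects every second step and
+ # --infer_steps 50 with --guidance_scale 5.0: steps 1-10 and 46-50 run CFG at 5.0, every second
+ # step in between runs CFG at 6.4, and the remaining steps skip the unconditional pass.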
+
39
+ ROOT_SCRIPT = os.path.basename(sys.argv[0]).lower()
40
+ if "hv_" in ROOT_SCRIPT:
41
+ DIFFUSION_MODEL = "hunyuan"
42
+ elif "wan_" in ROOT_SCRIPT:
43
+ DIFFUSION_MODEL = "wan"
44
+ elif "fpack_" in ROOT_SCRIPT:
45
+ DIFFUSION_MODEL = "framepack"
46
+ else:
47
+ raise ValueError("Unsupported root_script for Blissful Extension")
48
+
49
+ if "generate" in ROOT_SCRIPT:
50
+ MODE = "generate"
51
+ elif "train" in ROOT_SCRIPT:
52
+ MODE = "train"
53
+ else:
54
+ raise ValueError("Unsupported root script for Blissful Extension!")
55
+
56
+
57
+ def blissful_prefunc(args: argparse.Namespace):
58
+ """Simple function to print about version, environment, and things"""
59
+ cuda_list = [f"PyTorch: {torch.__version__}"]
60
+ if torch.cuda.is_available():
61
+ allocator = torch.cuda.get_allocator_backend()
62
+ cuda = torch.cuda.get_device_properties(0)
63
+ cuda_list[0] += f", CUDA: {torch.version.cuda} CC: {cuda.major}.{cuda.minor}"
64
+ cuda_list.append(f"Device: '{cuda.name}', VRAM: '{cuda.total_memory // 1024 ** 2}MB'")
65
+ for string in cuda_list:
66
+ logger.info(string)
67
+ if args.fp16_accumulation and MODE == "generate":
68
+ logger.info("Enabling FP16 accumulation")
69
+ if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
70
+ torch.backends.cuda.matmul.allow_fp16_accumulation = True
71
+ else:
72
+ raise ValueError("torch.backends.cuda.matmul.allow_fp16_accumulation is not available in this version of torch, requires torch 2.7.0.dev2025 02 26 nightly minimum")
73
+
74
+
75
+ def add_blissful_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
76
+ install_rich_tracebacks()
77
+ if DIFFUSION_MODEL == "wan":
78
+ parser.add_argument("--noise_aug_strength", type=float, default=0.0, help="Additional multiplier for i2v noise, higher might help motion/quality")
79
+ parser.add_argument("--prompt_weighting", action="store_true", help="Enable (prompt weighting:1.2)")
80
+ parser.add_argument(
81
+ "--rope_func", type=str, default="default",
82
+ help="Function to use for ROPE. Choose from 'default' or 'comfy' the latter of which uses ComfyUI implementation and is compilable with torch.compile to enable BIG VRAM savings"
83
+ )
84
+
85
+ elif DIFFUSION_MODEL == "hunyuan":
86
+ parser.add_argument("--hidden_state_skip_layer", type=int, default=2, help="Hidden state skip layer for LLM. Default is 2. Think 'clip skip' for the LLM")
87
+ parser.add_argument("--apply_final_norm", action="store_true", help="Apply final norm for LLM. Default is False. Usually makes things worse.")
88
+ parser.add_argument("--reproduce", action="store_true", help="Enable reproducible output (same seed = same result). Default is False.")
89
+ parser.add_argument("--fp8_scaled", action="store_true", help="Scaled FP8 quantization. Better quality/accuracy with slightly more VRAM usage.")
90
+ parser.add_argument("--prompt_2", type=str, required=False, help="Optional different prompt for CLIP")
91
+ parser.add_argument("--te_multiplier", nargs=2, metavar=("llm_multiplier", "clip_multiplier"), help="Scale clip and llm influence")
92
+ elif DIFFUSION_MODEL == "framepack":
93
+ parser.add_argument("--preview_latent_every", type=int, default=None, help="Enable latent preview every N sections. If --preview_vae is not specified it will use latent2rgb")
94
+
95
+ if DIFFUSION_MODEL in ["wan", "hunyuan"]:
96
+ parser.add_argument("--riflex_index", type=int, default=0, help="Frequency for RifleX extension. 4 is good for Hunyuan, 6 is good for Wan. Only 'comfy' rope_func supports this with Wan!")
97
+ parser.add_argument("--cfgzerostar_scaling", action="store_true", help="Enables CFG-Zero* scaling - https://github.com/WeichenFan/CFG-Zero-star")
98
+ parser.add_argument("--cfgzerostar_init_steps", type=int, default=-1, help="Enables CFGZero* zeroing out the first N steps. 2 is good for Wan T2V, 1 for I2V")
99
+ parser.add_argument("--preview_latent_every", type=int, default=None, help="Enable latent preview every N steps. If --preview_vae is not specified it will use latent2rgb")
100
+
101
+ # Common
102
+
103
+ parser.add_argument("--preview_vae", type=str, help="Path to TAE vae for taehv previews")
104
+ parser.add_argument("--cfg_schedule", type=str, help=CFG_SCHEDULE_HELP)
105
+ parser.add_argument("--keep_pngs", action="store_true", help="Save frames as PNGs in addition to output video")
106
+ parser.add_argument("--codec", choices=["prores", "h264", "h265"], default=None, help="Codec to use, choose from 'prores', 'h264', or 'h265'")
107
+ parser.add_argument("--container", choices=["mkv", "mp4"], default="mkv", help="Container format to use, choose from 'mkv' or 'mp4'. Note prores can only go in MKV!")
108
+ parser.add_argument("--fp16_accumulation", action="store_true", help="Enable full FP16 accumulation in FP16 GEMMs, requires PyTorch 2.7.0 or higher")
109
+ return parser
110
+
111
+
112
+ def parse_blissful_args(args: argparse.Namespace) -> argparse.Namespace:
113
+ if args.seed is not None:
114
+ try:
115
+ args.seed = int(args.seed)
116
+ except ValueError:
117
+ string_seed = args.seed
118
+ args.seed = string_to_seed(args.seed)
119
+ logger.info(f"Seed {args.seed} was generated from string '{string_seed}'!")
120
+ if DIFFUSION_MODEL == "wan":
121
+ if args.riflex_index != 0 and args.rope_func.lower() != "comfy":
122
+ logger.error("RIFLEx can only be used with rope_func == 'comfy'!")
123
+ raise ValueError("RIFLEx can only be used with rope_func == 'comfy'!")
124
+ if DIFFUSION_MODEL in ["wan", "hunyuan"]:
125
+ if args.cfg_schedule:
126
+ args.cfg_schedule = parse_scheduled_cfg(args.cfg_schedule, args.infer_steps, args.guidance_scale)
127
+ if args.cfgzerostar_scaling or args.cfgzerostar_init_steps != -1:
128
+ if args.guidance_scale == 1 and not args.cfg_schedule:
129
+ error_out(AttributeError, "Requested CFGZero* but CFG is not enabled so it will have no effect!")
130
+ blissful_prefunc(args)
131
+ return args
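A minimal sketch of how a host generation script is expected to consume these hooks. The module-level checks above require the entry script's filename to contain one of hv_/wan_/fpack_ plus generate or train (e.g. wan_generate_video.py), and parse_blissful_args assumes the host parser already defines --seed, --infer_steps and --guidance_scale; the three argument definitions below are placeholders for whatever the real script uses.

    import argparse
    from blissful_tuner.blissful_args import add_blissful_args, parse_blissful_args

    parser = argparse.ArgumentParser()
    # arguments normally supplied by the host script itself (placeholders)
    parser.add_argument("--seed", type=str, default=None)
    parser.add_argument("--infer_steps", type=int, default=50)
    parser.add_argument("--guidance_scale", type=float, default=5.0)

    parser = add_blissful_args(parser)   # append the extension's flags for the detected model/mode
    args = parser.parse_args()
    args = parse_blissful_args(args)     # resolve string seeds, expand --cfg_schedule, sanity-check flags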
blissful_tuner/blissful_settings.py ADDED
@@ -0,0 +1,111 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Tue Mar 11 19:08:55 2025
5
+
6
+ @author: blyss
7
+ """
8
+ import os
9
+ import json
10
+
11
+
12
+ class SingletonMeta(type):
13
+ """
14
+ The SingletonMeta class is useful for creating objects that persist as a single instance across the whole program. Basically a global class.
15
+ """
16
+ _instances = {}
17
+
18
+ def __call__(cls, *Parameters, **kwParameters):
19
+ if cls not in cls._instances:
20
+ cls._instances[cls] = super(SingletonMeta, cls).__call__(*Parameters, **kwParameters)
21
+ return cls._instances[cls]
22
+
23
+
24
+ class BlissfulSettings(metaclass=SingletonMeta):
25
+ def __init__(self):
26
+ """
27
+ Loads the settings from a 'settings.json' file, creating it with default settings if it doesn't exist.
28
+
29
+ This method attempts to read the program's settings from a JSON file. If the file does not exist,
30
+ it creates a new file with default settings. This ensures that the program can start with a known
31
+ set of configurations and modify them as needed.
32
+
33
+ This class uses SingletonMeta as its metaclass, so even if it is instantiated again, this initialization only happens the first time.
34
+ """
35
+ # These are globals that do not persist
36
+ self.generating = 0
37
+ self.last_preview_file = ""
38
+
39
+ default_settings = {
40
+ "prompt": "a cat walks on the grass, realistic style",
41
+ "resolution_x": 960,
42
+ "resolution_y": 544,
43
+ "fps": 24,
44
+ "embedded_guidance": 6.0,
45
+ "flow_shift": 7.0,
46
+ "infer_steps": 50,
47
+ "seed": 42,
48
+ "video_length": 129,
49
+ "attention": "sage",
50
+ "blocks_to_swap": 0,
51
+ "hidden_state_skip_layer": 2,
52
+ "apply_final_norm": False,
53
+ "reproduce": False,
54
+ "fp8": True,
55
+ "fp8_fast": False,
56
+ "do_compile": False,
57
+ "transformer_path": "",
58
+ "text_encoder_1_path": "",
59
+ "text_encoder_2_path": "",
60
+ "vae_path": "",
61
+ "lora_path": "",
62
+ }
63
+
64
+ if not os.path.exists("./settings.json"):
65
+ with open("./settings.json", "w", encoding="utf-8") as file:
66
+ json.dump(default_settings, file, indent=4)
67
+ print("No existing settings found. Created default settings file.")
68
+
69
+ with open("./settings.json", "r", encoding="utf-8") as file:
70
+ data = json.load(file)
71
+
72
+ for key, default_value in default_settings.items():
73
+ setattr(self, key, data.get(key, default_value))
74
+
75
+ def save_to_file(self):
76
+ """
77
+ Saves the current settings to a JSON file named 'settings.json'.
78
+ """
79
+ settings = {
80
+ "prompt": self.prompt,
81
+ "resolution_x": self.resolution_x,
82
+ "resolution_y": self.resolution_y,
83
+ "fps": self.fps,
84
+ "embedded_guidance": self.embedded_guidance,
85
+ "flow_shift": self.flow_shift,
86
+ "infer_steps": self.infer_steps,
87
+ "seed": self.seed,
88
+ "video_length": self.video_length,
89
+ "attention": self.attention,
90
+ "blocks_to_swap": self.blocks_to_swap,
91
+ "hidden_state_skip_layer": self.hidden_state_skip_layer,
92
+ "apply_final_norm": self.apply_final_norm,
93
+ "reproduce": self.reproduce,
94
+ "fp8": self.fp8,
95
+ "fp8_fast": self.fp8_fast,
96
+ "do_compile": self.do_compile,
97
+ "transformer_path": self.transformer_path,
98
+ "text_encoder_1_path": self.text_encoder_1_path,
99
+ "text_encoder_2_path": self.text_encoder_2_path,
100
+ "vae_path": self.vae_path,
101
+ "lora_path": self.lora_path,
102
+ }
103
+
104
+ with open("./settings.json", "w", encoding="utf-8") as file:
105
+ json.dump(settings, file, indent=4)
106
+
107
+ def update(self, option, value, label_target=None, label_value=None):
108
+ """Update a setting; called via a Qt signal connection and optionally updates an associated label widget."""
109
+ setattr(self, option, value)
110
+ if label_target is not None and label_value is not None:
111
+ label_target.setText(str(label_value))
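A short sketch of how this settings object might be used; it reads and writes ./settings.json relative to the current working directory, and the SingletonMeta metaclass guarantees every caller sees the same instance.

    from blissful_tuner.blissful_settings import BlissfulSettings

    settings = BlissfulSettings()     # first construction loads (or creates) ./settings.json
    same = BlissfulSettings()
    assert settings is same           # singleton: both names refer to one object

    settings.update("fps", 30)        # label_target/label_value are optional Qt widgets, omitted here
    settings.save_to_file()           # persist the change back to ./settings.json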
blissful_tuner/cfgzerostar.py ADDED
@@ -0,0 +1,39 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Wed Apr 16 17:23:28 2025
5
+ CFGZero* implementation for Blissful Tuner extension based on https://github.com/WeichenFan/CFG-Zero-star/blob/main/models/wan/wan_pipeline.py
6
+ License: Apache 2.0
7
+ @author: blyss
8
+ """
9
+ import torch
10
+
11
+
12
+ def apply_zerostar(cond: torch.Tensor, uncond: torch.Tensor, current_step: int, guidance_scale: float, use_scaling: bool = True, zero_init_steps: int = -1) -> torch.Tensor:
13
+
14
+ if (current_step <= zero_init_steps):
15
+ return cond * 0
16
+ if not use_scaling:
17
+ # CFG formula
18
+ noise_pred = uncond + guidance_scale * (cond - uncond)
19
+ else:
20
+ batch_size = cond.shape[0]
21
+ positive_flat = cond.view(batch_size, -1)
22
+ negative_flat = uncond.view(batch_size, -1)
23
+ alpha = optimized_scale(positive_flat, negative_flat)
24
+ alpha = alpha.view(batch_size, *([1] * (len(cond.shape) - 1)))
25
+ alpha = alpha.to(cond.dtype)
26
+ # CFG formula modified with alpha
27
+ noise_pred = uncond * alpha + guidance_scale * (cond - uncond * alpha)
28
+ return noise_pred
29
+
30
+
31
+ def optimized_scale(positive_flat, negative_flat):
32
+
33
+ dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
34
+ squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
35
+
36
+ # st_star = v_cond^T * v_uncond / ||v_uncond||^2
37
+ st_star = dot_product / squared_norm
38
+
39
+ return st_star
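optimized_scale computes the projection coefficient st* = (v_cond . v_uncond) / ||v_uncond||^2 per batch element, and apply_zerostar uses it as the alpha in the modified CFG formula above. A hedged sketch of where this slots into a sampling loop follows; the random tensors stand in for the model's conditional and unconditional noise predictions, which the real pipelines produce elsewhere.

    import torch
    from blissful_tuner.cfgzerostar import apply_zerostar

    guidance_scale = 5.0
    for step in range(50):
        cond = torch.randn(1, 16, 13, 60, 104)   # placeholder conditional prediction
        uncond = torch.randn_like(cond)          # placeholder unconditional prediction
        noise_pred = apply_zerostar(
            cond, uncond, current_step=step,
            guidance_scale=guidance_scale,
            use_scaling=True,                    # CFG-Zero* alpha scaling
            zero_init_steps=0,                   # zero out the first step (step 0)
        )
        # ...scheduler update with noise_pred would go here...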
blissful_tuner/codeformer/LICENSE ADDED
@@ -0,0 +1,15 @@
1
+ THIS FOLDER AND SUBFOLDERS (all CodeFormer related code and files) LICENSED AS BELOW
2
+
3
+ S-Lab License 1.0
4
+ Copyright 2024 S-Lab
5
+
6
+ Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
8
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
9
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
10
+
11
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
12
+ IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
13
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
14
+
15
+ In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
blissful_tuner/codeformer/basicsr/VERSION ADDED
@@ -0,0 +1 @@
1
+ 1.3.2
blissful_tuner/codeformer/basicsr/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # https://github.com/xinntao/BasicSR
2
+ # flake8: noqa
3
+ from .archs import *
4
+ from .data import *
5
+ from .losses import *
6
+ from .metrics import *
7
+ from .models import *
8
+ from .ops import *
9
+ from .train import *
10
+ from .utils import *
11
+ #from .version import __gitsha__, __version__
blissful_tuner/codeformer/basicsr/archs/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ import importlib
2
+ from copy import deepcopy
3
+ from os import path as osp
4
+
5
+ from codeformer.basicsr.utils import get_root_logger, scandir
6
+ from codeformer.basicsr.utils.registry import ARCH_REGISTRY
7
+
8
+ __all__ = ['build_network']
9
+
10
+ # automatically scan and import arch modules for registry
11
+ # scan all the files under the 'archs' folder and collect files ending with
12
+ # '_arch.py'
13
+ arch_folder = osp.dirname(osp.abspath(__file__))
14
+ arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
15
+ # import all the arch modules
16
+ _arch_modules = [importlib.import_module(f'codeformer.basicsr.archs.{file_name}') for file_name in arch_filenames]
17
+
18
+
19
+ def build_network(opt):
20
+ opt = deepcopy(opt)
21
+ network_type = opt.pop('type')
22
+ net = ARCH_REGISTRY.get(network_type)(**opt)
23
+ logger = get_root_logger()
24
+ logger.info(f'Network [{net.__class__.__name__}] is created.')
25
+ return net
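For example, build_network takes an options dict whose 'type' must name an architecture registered in ARCH_REGISTRY (such as RRDBNet or CodeFormer from this folder); the remaining keys are passed to that class's constructor. The values below are illustrative, not a config shipped with this upload.

    from codeformer.basicsr.archs import build_network

    opt = {"type": "RRDBNet", "num_in_ch": 3, "num_out_ch": 3, "scale": 4, "num_feat": 64, "num_block": 23}
    net = build_network(opt)   # pops 'type', looks it up in ARCH_REGISTRY, instantiates with the rest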
blissful_tuner/codeformer/basicsr/archs/arcface_arch.py ADDED
@@ -0,0 +1,245 @@
1
+ import torch.nn as nn
2
+ from codeformer.basicsr.utils.registry import ARCH_REGISTRY
3
+
4
+
5
+ def conv3x3(inplanes, outplanes, stride=1):
6
+ """A simple wrapper for 3x3 convolution with padding.
7
+
8
+ Args:
9
+ inplanes (int): Channel number of inputs.
10
+ outplanes (int): Channel number of outputs.
11
+ stride (int): Stride in convolution. Default: 1.
12
+ """
13
+ return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
14
+
15
+
16
+ class BasicBlock(nn.Module):
17
+ """Basic residual block used in the ResNetArcFace architecture.
18
+
19
+ Args:
20
+ inplanes (int): Channel number of inputs.
21
+ planes (int): Channel number of outputs.
22
+ stride (int): Stride in convolution. Default: 1.
23
+ downsample (nn.Module): The downsample module. Default: None.
24
+ """
25
+ expansion = 1 # output channel expansion ratio
26
+
27
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
28
+ super(BasicBlock, self).__init__()
29
+ self.conv1 = conv3x3(inplanes, planes, stride)
30
+ self.bn1 = nn.BatchNorm2d(planes)
31
+ self.relu = nn.ReLU(inplace=True)
32
+ self.conv2 = conv3x3(planes, planes)
33
+ self.bn2 = nn.BatchNorm2d(planes)
34
+ self.downsample = downsample
35
+ self.stride = stride
36
+
37
+ def forward(self, x):
38
+ residual = x
39
+
40
+ out = self.conv1(x)
41
+ out = self.bn1(out)
42
+ out = self.relu(out)
43
+
44
+ out = self.conv2(out)
45
+ out = self.bn2(out)
46
+
47
+ if self.downsample is not None:
48
+ residual = self.downsample(x)
49
+
50
+ out += residual
51
+ out = self.relu(out)
52
+
53
+ return out
54
+
55
+
56
+ class IRBlock(nn.Module):
57
+ """Improved residual block (IR Block) used in the ResNetArcFace architecture.
58
+
59
+ Args:
60
+ inplanes (int): Channel number of inputs.
61
+ planes (int): Channel number of outputs.
62
+ stride (int): Stride in convolution. Default: 1.
63
+ downsample (nn.Module): The downsample module. Default: None.
64
+ use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: True.
65
+ """
66
+ expansion = 1 # output channel expansion ratio
67
+
68
+ def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
69
+ super(IRBlock, self).__init__()
70
+ self.bn0 = nn.BatchNorm2d(inplanes)
71
+ self.conv1 = conv3x3(inplanes, inplanes)
72
+ self.bn1 = nn.BatchNorm2d(inplanes)
73
+ self.prelu = nn.PReLU()
74
+ self.conv2 = conv3x3(inplanes, planes, stride)
75
+ self.bn2 = nn.BatchNorm2d(planes)
76
+ self.downsample = downsample
77
+ self.stride = stride
78
+ self.use_se = use_se
79
+ if self.use_se:
80
+ self.se = SEBlock(planes)
81
+
82
+ def forward(self, x):
83
+ residual = x
84
+ out = self.bn0(x)
85
+ out = self.conv1(out)
86
+ out = self.bn1(out)
87
+ out = self.prelu(out)
88
+
89
+ out = self.conv2(out)
90
+ out = self.bn2(out)
91
+ if self.use_se:
92
+ out = self.se(out)
93
+
94
+ if self.downsample is not None:
95
+ residual = self.downsample(x)
96
+
97
+ out += residual
98
+ out = self.prelu(out)
99
+
100
+ return out
101
+
102
+
103
+ class Bottleneck(nn.Module):
104
+ """Bottleneck block used in the ResNetArcFace architecture.
105
+
106
+ Args:
107
+ inplanes (int): Channel number of inputs.
108
+ planes (int): Channel number of outputs.
109
+ stride (int): Stride in convolution. Default: 1.
110
+ downsample (nn.Module): The downsample module. Default: None.
111
+ """
112
+ expansion = 4 # output channel expansion ratio
113
+
114
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
115
+ super(Bottleneck, self).__init__()
116
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
117
+ self.bn1 = nn.BatchNorm2d(planes)
118
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
119
+ self.bn2 = nn.BatchNorm2d(planes)
120
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
121
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
122
+ self.relu = nn.ReLU(inplace=True)
123
+ self.downsample = downsample
124
+ self.stride = stride
125
+
126
+ def forward(self, x):
127
+ residual = x
128
+
129
+ out = self.conv1(x)
130
+ out = self.bn1(out)
131
+ out = self.relu(out)
132
+
133
+ out = self.conv2(out)
134
+ out = self.bn2(out)
135
+ out = self.relu(out)
136
+
137
+ out = self.conv3(out)
138
+ out = self.bn3(out)
139
+
140
+ if self.downsample is not None:
141
+ residual = self.downsample(x)
142
+
143
+ out += residual
144
+ out = self.relu(out)
145
+
146
+ return out
147
+
148
+
149
+ class SEBlock(nn.Module):
150
+ """The squeeze-and-excitation block (SEBlock) used in the IRBlock.
151
+
152
+ Args:
153
+ channel (int): Channel number of inputs.
154
+ reduction (int): Channel reduction ratio. Default: 16.
155
+ """
156
+
157
+ def __init__(self, channel, reduction=16):
158
+ super(SEBlock, self).__init__()
159
+ self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information
160
+ self.fc = nn.Sequential(
161
+ nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
162
+ nn.Sigmoid())
163
+
164
+ def forward(self, x):
165
+ b, c, _, _ = x.size()
166
+ y = self.avg_pool(x).view(b, c)
167
+ y = self.fc(y).view(b, c, 1, 1)
168
+ return x * y
169
+
170
+
171
+ @ARCH_REGISTRY.register()
172
+ class ResNetArcFace(nn.Module):
173
+ """ArcFace with ResNet architectures.
174
+
175
+ Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
176
+
177
+ Args:
178
+ block (str): Block used in the ArcFace architecture.
179
+ layers (tuple(int)): Block numbers in each layer.
180
+ use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: True.
181
+ """
182
+
183
+ def __init__(self, block, layers, use_se=True):
184
+ if block == 'IRBlock':
185
+ block = IRBlock
186
+ self.inplanes = 64
187
+ self.use_se = use_se
188
+ super(ResNetArcFace, self).__init__()
189
+
190
+ self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
191
+ self.bn1 = nn.BatchNorm2d(64)
192
+ self.prelu = nn.PReLU()
193
+ self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
194
+ self.layer1 = self._make_layer(block, 64, layers[0])
195
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
196
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
197
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
198
+ self.bn4 = nn.BatchNorm2d(512)
199
+ self.dropout = nn.Dropout()
200
+ self.fc5 = nn.Linear(512 * 8 * 8, 512)
201
+ self.bn5 = nn.BatchNorm1d(512)
202
+
203
+ # initialization
204
+ for m in self.modules():
205
+ if isinstance(m, nn.Conv2d):
206
+ nn.init.xavier_normal_(m.weight)
207
+ elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
208
+ nn.init.constant_(m.weight, 1)
209
+ nn.init.constant_(m.bias, 0)
210
+ elif isinstance(m, nn.Linear):
211
+ nn.init.xavier_normal_(m.weight)
212
+ nn.init.constant_(m.bias, 0)
213
+
214
+ def _make_layer(self, block, planes, num_blocks, stride=1):
215
+ downsample = None
216
+ if stride != 1 or self.inplanes != planes * block.expansion:
217
+ downsample = nn.Sequential(
218
+ nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
219
+ nn.BatchNorm2d(planes * block.expansion),
220
+ )
221
+ layers = []
222
+ layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
223
+ self.inplanes = planes
224
+ for _ in range(1, num_blocks):
225
+ layers.append(block(self.inplanes, planes, use_se=self.use_se))
226
+
227
+ return nn.Sequential(*layers)
228
+
229
+ def forward(self, x):
230
+ x = self.conv1(x)
231
+ x = self.bn1(x)
232
+ x = self.prelu(x)
233
+ x = self.maxpool(x)
234
+
235
+ x = self.layer1(x)
236
+ x = self.layer2(x)
237
+ x = self.layer3(x)
238
+ x = self.layer4(x)
239
+ x = self.bn4(x)
240
+ x = self.dropout(x)
241
+ x = x.view(x.size(0), -1)
242
+ x = self.fc5(x)
243
+ x = self.bn5(x)
244
+
245
+ return x
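A small shape check, assuming the common ResNet-18-style layer configuration (2, 2, 2, 2): conv1 takes a single-channel input and fc5 expects an 8x8 spatial map, so a 1x128x128 grayscale face crop yields a 512-dimensional identity embedding.

    import torch
    from codeformer.basicsr.archs.arcface_arch import ResNetArcFace

    net = ResNetArcFace(block="IRBlock", layers=(2, 2, 2, 2), use_se=True).eval()
    x = torch.randn(4, 1, 128, 128)   # batch of grayscale face crops
    with torch.no_grad():
        emb = net(x)
    print(emb.shape)                  # torch.Size([4, 512])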
blissful_tuner/codeformer/basicsr/archs/arch_util.py ADDED
@@ -0,0 +1,318 @@
1
+ import collections.abc
2
+ import math
3
+ import torch
4
+ import torchvision
5
+ import warnings
6
+ from distutils.version import LooseVersion
7
+ from itertools import repeat
8
+ from torch import nn as nn
9
+ from torch.nn import functional as F
10
+ from torch.nn import init as init
11
+ from torch.nn.modules.batchnorm import _BatchNorm
12
+
13
+ from codeformer.basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
14
+ from codeformer.basicsr.utils import get_root_logger
15
+
16
+
17
+ @torch.no_grad()
18
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
19
+ """Initialize network weights.
20
+
21
+ Args:
22
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
23
+ scale (float): Scale initialized weights, especially for residual
24
+ blocks. Default: 1.
25
+ bias_fill (float): The value to fill bias. Default: 0
26
+ kwargs (dict): Other arguments for initialization function.
27
+ """
28
+ if not isinstance(module_list, list):
29
+ module_list = [module_list]
30
+ for module in module_list:
31
+ for m in module.modules():
32
+ if isinstance(m, nn.Conv2d):
33
+ init.kaiming_normal_(m.weight, **kwargs)
34
+ m.weight.data *= scale
35
+ if m.bias is not None:
36
+ m.bias.data.fill_(bias_fill)
37
+ elif isinstance(m, nn.Linear):
38
+ init.kaiming_normal_(m.weight, **kwargs)
39
+ m.weight.data *= scale
40
+ if m.bias is not None:
41
+ m.bias.data.fill_(bias_fill)
42
+ elif isinstance(m, _BatchNorm):
43
+ init.constant_(m.weight, 1)
44
+ if m.bias is not None:
45
+ m.bias.data.fill_(bias_fill)
46
+
47
+
48
+ def make_layer(basic_block, num_basic_block, **kwarg):
49
+ """Make layers by stacking the same blocks.
50
+
51
+ Args:
52
+ basic_block (nn.module): nn.module class for basic block.
53
+ num_basic_block (int): number of blocks.
54
+
55
+ Returns:
56
+ nn.Sequential: Stacked blocks in nn.Sequential.
57
+ """
58
+ layers = []
59
+ for _ in range(num_basic_block):
60
+ layers.append(basic_block(**kwarg))
61
+ return nn.Sequential(*layers)
62
+
63
+
64
+ class ResidualBlockNoBN(nn.Module):
65
+ """Residual block without BN.
66
+
67
+ It has a style of:
68
+ ---Conv-ReLU-Conv-+-
69
+ |________________|
70
+
71
+ Args:
72
+ num_feat (int): Channel number of intermediate features.
73
+ Default: 64.
74
+ res_scale (float): Residual scale. Default: 1.
75
+ pytorch_init (bool): If set to True, use pytorch default init,
76
+ otherwise, use default_init_weights. Default: False.
77
+ """
78
+
79
+ def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
80
+ super(ResidualBlockNoBN, self).__init__()
81
+ self.res_scale = res_scale
82
+ self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
83
+ self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
84
+ self.relu = nn.ReLU(inplace=True)
85
+
86
+ if not pytorch_init:
87
+ default_init_weights([self.conv1, self.conv2], 0.1)
88
+
89
+ def forward(self, x):
90
+ identity = x
91
+ out = self.conv2(self.relu(self.conv1(x)))
92
+ return identity + out * self.res_scale
93
+
94
+
95
+ class Upsample(nn.Sequential):
96
+ """Upsample module.
97
+
98
+ Args:
99
+ scale (int): Scale factor. Supported scales: 2^n and 3.
100
+ num_feat (int): Channel number of intermediate features.
101
+ """
102
+
103
+ def __init__(self, scale, num_feat):
104
+ m = []
105
+ if (scale & (scale - 1)) == 0: # scale = 2^n
106
+ for _ in range(int(math.log(scale, 2))):
107
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
108
+ m.append(nn.PixelShuffle(2))
109
+ elif scale == 3:
110
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
111
+ m.append(nn.PixelShuffle(3))
112
+ else:
113
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
114
+ super(Upsample, self).__init__(*m)
115
+
116
+
117
+ def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
118
+ """Warp an image or feature map with optical flow.
119
+
120
+ Args:
121
+ x (Tensor): Tensor with size (n, c, h, w).
122
+ flow (Tensor): Tensor with size (n, h, w, 2), normal value.
123
+ interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
124
+ padding_mode (str): 'zeros' or 'border' or 'reflection'.
125
+ Default: 'zeros'.
126
+ align_corners (bool): Before pytorch 1.3, the default value is
127
+ align_corners=True. After pytorch 1.3, the default value is
128
+ align_corners=False. Here, we use the True as default.
129
+
130
+ Returns:
131
+ Tensor: Warped image or feature map.
132
+ """
133
+ assert x.size()[-2:] == flow.size()[1:3]
134
+ _, _, h, w = x.size()
135
+ # create mesh grid
136
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
137
+ grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
138
+ grid.requires_grad = False
139
+
140
+ vgrid = grid + flow
141
+ # scale grid to [-1,1]
142
+ vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
143
+ vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
144
+ vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
145
+ output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
146
+
147
+ # TODO, what if align_corners=False
148
+ return output
149
+
150
+
151
+ def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
152
+ """Resize a flow according to ratio or shape.
153
+
154
+ Args:
155
+ flow (Tensor): Precomputed flow. shape [N, 2, H, W].
156
+ size_type (str): 'ratio' or 'shape'.
157
+ sizes (list[int | float]): the ratio for resizing or the final output
158
+ shape.
159
+ 1) The order of ratio should be [ratio_h, ratio_w]. For
160
+ downsampling, the ratio should be smaller than 1.0 (i.e., ratio
161
+ < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
162
+ ratio > 1.0).
163
+ 2) The order of output_size should be [out_h, out_w].
164
+ interp_mode (str): The mode of interpolation for resizing.
165
+ Default: 'bilinear'.
166
+ align_corners (bool): Whether align corners. Default: False.
167
+
168
+ Returns:
169
+ Tensor: Resized flow.
170
+ """
171
+ _, _, flow_h, flow_w = flow.size()
172
+ if size_type == 'ratio':
173
+ output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
174
+ elif size_type == 'shape':
175
+ output_h, output_w = sizes[0], sizes[1]
176
+ else:
177
+ raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
178
+
179
+ input_flow = flow.clone()
180
+ ratio_h = output_h / flow_h
181
+ ratio_w = output_w / flow_w
182
+ input_flow[:, 0, :, :] *= ratio_w
183
+ input_flow[:, 1, :, :] *= ratio_h
184
+ resized_flow = F.interpolate(
185
+ input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
186
+ return resized_flow
187
+
188
+
189
+ # TODO: may write a cpp file
190
+ def pixel_unshuffle(x, scale):
191
+ """ Pixel unshuffle.
192
+
193
+ Args:
194
+ x (Tensor): Input feature with shape (b, c, hh, hw).
195
+ scale (int): Downsample ratio.
196
+
197
+ Returns:
198
+ Tensor: the pixel unshuffled feature.
199
+ """
200
+ b, c, hh, hw = x.size()
201
+ out_channel = c * (scale**2)
202
+ assert hh % scale == 0 and hw % scale == 0
203
+ h = hh // scale
204
+ w = hw // scale
205
+ x_view = x.view(b, c, h, scale, w, scale)
206
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
207
+
208
+
209
+ class DCNv2Pack(ModulatedDeformConvPack):
210
+ """Modulated deformable conv for deformable alignment.
211
+
212
+ Different from the official DCNv2Pack, which generates offsets and masks
213
+ from the preceding features, this DCNv2Pack takes another different
214
+ features to generate offsets and masks.
215
+
216
+ Ref:
217
+ Delving Deep into Deformable Alignment in Video Super-Resolution.
218
+ """
219
+
220
+ def forward(self, x, feat):
221
+ out = self.conv_offset(feat)
222
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
223
+ offset = torch.cat((o1, o2), dim=1)
224
+ mask = torch.sigmoid(mask)
225
+
226
+ offset_absmean = torch.mean(torch.abs(offset))
227
+ if offset_absmean > 50:
228
+ logger = get_root_logger()
229
+ logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
230
+
231
+ if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
232
+ return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
233
+ self.dilation, mask)
234
+ else:
235
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
236
+ self.dilation, self.groups, self.deformable_groups)
237
+
238
+
239
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
240
+ # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
241
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
242
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
243
+ def norm_cdf(x):
244
+ # Computes standard normal cumulative distribution function
245
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
246
+
247
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
248
+ warnings.warn(
249
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
250
+ 'The distribution of values may be incorrect.',
251
+ stacklevel=2)
252
+
253
+ with torch.no_grad():
254
+ # Values are generated by using a truncated uniform distribution and
255
+ # then using the inverse CDF for the normal distribution.
256
+ # Get upper and lower cdf values
257
+ low = norm_cdf((a - mean) / std)
258
+ up = norm_cdf((b - mean) / std)
259
+
260
+ # Uniformly fill tensor with values from [low, up], then translate to
261
+ # [2l-1, 2u-1].
262
+ tensor.uniform_(2 * low - 1, 2 * up - 1)
263
+
264
+ # Use inverse cdf transform for normal distribution to get truncated
265
+ # standard normal
266
+ tensor.erfinv_()
267
+
268
+ # Transform to proper mean, std
269
+ tensor.mul_(std * math.sqrt(2.))
270
+ tensor.add_(mean)
271
+
272
+ # Clamp to ensure it's in the proper range
273
+ tensor.clamp_(min=a, max=b)
274
+ return tensor
275
+
276
+
277
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
278
+ r"""Fills the input Tensor with values drawn from a truncated
279
+ normal distribution.
280
+
281
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
282
+
283
+ The values are effectively drawn from the
284
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
285
+ with values outside :math:`[a, b]` redrawn until they are within
286
+ the bounds. The method used for generating the random values works
287
+ best when :math:`a \leq \text{mean} \leq b`.
288
+
289
+ Args:
290
+ tensor: an n-dimensional `torch.Tensor`
291
+ mean: the mean of the normal distribution
292
+ std: the standard deviation of the normal distribution
293
+ a: the minimum cutoff value
294
+ b: the maximum cutoff value
295
+
296
+ Examples:
297
+ >>> w = torch.empty(3, 5)
298
+ >>> nn.init.trunc_normal_(w)
299
+ """
300
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
301
+
302
+
303
+ # From PyTorch
304
+ def _ntuple(n):
305
+
306
+ def parse(x):
307
+ if isinstance(x, collections.abc.Iterable):
308
+ return x
309
+ return tuple(repeat(x, n))
310
+
311
+ return parse
312
+
313
+
314
+ to_1tuple = _ntuple(1)
315
+ to_2tuple = _ntuple(2)
316
+ to_3tuple = _ntuple(3)
317
+ to_4tuple = _ntuple(4)
318
+ to_ntuple = _ntuple
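Two quick sanity checks of the helpers above (tensor sizes chosen here; importing this module also pulls in the optional DCN ops, which are assumed to import cleanly):

    import torch
    from codeformer.basicsr.archs.arch_util import pixel_unshuffle, to_2tuple

    x = torch.randn(1, 3, 8, 8)
    y = pixel_unshuffle(x, scale=2)
    print(y.shape)            # torch.Size([1, 12, 4, 4]): channels grow by scale**2, spatial dims shrink

    print(to_2tuple(7))       # (7, 7)
    print(to_2tuple((3, 5)))  # (3, 5): iterables pass through unchanged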
blissful_tuner/codeformer/basicsr/archs/codeformer_arch.py ADDED
@@ -0,0 +1,280 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn, Tensor
5
+ import torch.nn.functional as F
6
+ from typing import Optional, List
7
+
8
+ from codeformer.basicsr.archs.vqgan_arch import *
9
+ from codeformer.basicsr.utils import get_root_logger
10
+ from codeformer.basicsr.utils.registry import ARCH_REGISTRY
11
+
12
+ def calc_mean_std(feat, eps=1e-5):
13
+ """Calculate mean and std for adaptive_instance_normalization.
14
+
15
+ Args:
16
+ feat (Tensor): 4D tensor.
17
+ eps (float): A small value added to the variance to avoid
18
+ divide-by-zero. Default: 1e-5.
19
+ """
20
+ size = feat.size()
21
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
22
+ b, c = size[:2]
23
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
24
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
25
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
26
+ return feat_mean, feat_std
27
+
28
+
29
+ def adaptive_instance_normalization(content_feat, style_feat):
30
+ """Adaptive instance normalization.
31
+
32
+ Adjust the reference features so that their color and illumination are similar
33
+ to those in the degraded features.
34
+
35
+ Args:
36
+ content_feat (Tensor): The reference feature.
37
+ style_feat (Tensor): The degraded features.
38
+ """
39
+ size = content_feat.size()
40
+ style_mean, style_std = calc_mean_std(style_feat)
41
+ content_mean, content_std = calc_mean_std(content_feat)
42
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
43
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
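For intuition (tensors invented here, not from the repository): AdaIN re-normalizes the content features so their per-channel statistics match those of the style features while keeping the content's spatial structure.

    content = torch.randn(1, 64, 16, 16) * 3.0 + 1.0   # arbitrary "reference" features
    style = torch.randn(1, 64, 16, 16)                  # arbitrary "degraded" features
    out = adaptive_instance_normalization(content, style)
    # out now has (approximately) the per-channel mean/std of `style`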
44
+
45
+
46
+ class PositionEmbeddingSine(nn.Module):
47
+ """
48
+ This is a more standard version of the position embedding, very similar to the one
49
+ used by the Attention is all you need paper, generalized to work on images.
50
+ """
51
+
52
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
53
+ super().__init__()
54
+ self.num_pos_feats = num_pos_feats
55
+ self.temperature = temperature
56
+ self.normalize = normalize
57
+ if scale is not None and normalize is False:
58
+ raise ValueError("normalize should be True if scale is passed")
59
+ if scale is None:
60
+ scale = 2 * math.pi
61
+ self.scale = scale
62
+
63
+ def forward(self, x, mask=None):
64
+ if mask is None:
65
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
66
+ not_mask = ~mask
67
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
68
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
69
+ if self.normalize:
70
+ eps = 1e-6
71
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
72
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
73
+
74
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
75
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
76
+
77
+ pos_x = x_embed[:, :, :, None] / dim_t
78
+ pos_y = y_embed[:, :, :, None] / dim_t
79
+ pos_x = torch.stack(
80
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
81
+ ).flatten(3)
82
+ pos_y = torch.stack(
83
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
84
+ ).flatten(3)
85
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
86
+ return pos
87
+
88
+ def _get_activation_fn(activation):
89
+ """Return an activation function given a string"""
90
+ if activation == "relu":
91
+ return F.relu
92
+ if activation == "gelu":
93
+ return F.gelu
94
+ if activation == "glu":
95
+ return F.glu
96
+ raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
97
+
98
+
99
+ class TransformerSALayer(nn.Module):
100
+ def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
101
+ super().__init__()
102
+ self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
103
+ # Implementation of Feedforward model - MLP
104
+ self.linear1 = nn.Linear(embed_dim, dim_mlp)
105
+ self.dropout = nn.Dropout(dropout)
106
+ self.linear2 = nn.Linear(dim_mlp, embed_dim)
107
+
108
+ self.norm1 = nn.LayerNorm(embed_dim)
109
+ self.norm2 = nn.LayerNorm(embed_dim)
110
+ self.dropout1 = nn.Dropout(dropout)
111
+ self.dropout2 = nn.Dropout(dropout)
112
+
113
+ self.activation = _get_activation_fn(activation)
114
+
115
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
116
+ return tensor if pos is None else tensor + pos
117
+
118
+ def forward(self, tgt,
119
+ tgt_mask: Optional[Tensor] = None,
120
+ tgt_key_padding_mask: Optional[Tensor] = None,
121
+ query_pos: Optional[Tensor] = None):
122
+
123
+ # self attention
124
+ tgt2 = self.norm1(tgt)
125
+ q = k = self.with_pos_embed(tgt2, query_pos)
126
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
127
+ key_padding_mask=tgt_key_padding_mask)[0]
128
+ tgt = tgt + self.dropout1(tgt2)
129
+
130
+ # ffn
131
+ tgt2 = self.norm2(tgt)
132
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
133
+ tgt = tgt + self.dropout2(tgt2)
134
+ return tgt
135
+
136
+ class Fuse_sft_block(nn.Module):
137
+ def __init__(self, in_ch, out_ch):
138
+ super().__init__()
139
+ self.encode_enc = ResBlock(2*in_ch, out_ch)
140
+
141
+ self.scale = nn.Sequential(
142
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
143
+ nn.LeakyReLU(0.2, True),
144
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
145
+
146
+ self.shift = nn.Sequential(
147
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
148
+ nn.LeakyReLU(0.2, True),
149
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
150
+
151
+ def forward(self, enc_feat, dec_feat, w=1):
152
+ enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
153
+ scale = self.scale(enc_feat)
154
+ shift = self.shift(enc_feat)
155
+ residual = w * (dec_feat * scale + shift)
156
+ out = dec_feat + residual
157
+ return out
158
+
159
+
160
+ @ARCH_REGISTRY.register()
161
+ class CodeFormer(VQAutoEncoder):
162
+ def __init__(self, dim_embd=512, n_head=8, n_layers=9,
163
+ codebook_size=1024, latent_size=256,
164
+ connect_list=['32', '64', '128', '256'],
165
+ fix_modules=['quantize','generator'], vqgan_path=None):
166
+ super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
167
+
168
+ if vqgan_path is not None:
169
+ self.load_state_dict(
170
+ torch.load(vqgan_path, map_location='cpu')['params_ema'])
171
+
172
+ if fix_modules is not None:
173
+ for module in fix_modules:
174
+ for param in getattr(self, module).parameters():
175
+ param.requires_grad = False
176
+
177
+ self.connect_list = connect_list
178
+ self.n_layers = n_layers
179
+ self.dim_embd = dim_embd
180
+ self.dim_mlp = dim_embd*2
181
+
182
+ self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
183
+ self.feat_emb = nn.Linear(256, self.dim_embd)
184
+
185
+ # transformer
186
+ self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
187
+ for _ in range(self.n_layers)])
188
+
189
+ # logits_predict head
190
+ self.idx_pred_layer = nn.Sequential(
191
+ nn.LayerNorm(dim_embd),
192
+ nn.Linear(dim_embd, codebook_size, bias=False))
193
+
194
+ self.channels = {
195
+ '16': 512,
196
+ '32': 256,
197
+ '64': 256,
198
+ '128': 128,
199
+ '256': 128,
200
+ '512': 64,
201
+ }
202
+
203
+ # after second residual block for > 16, before attn layer for ==16
204
+ self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
205
+ # after first residual block for > 16, before attn layer for ==16
206
+ self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
207
+
208
+ # fuse_convs_dict
209
+ self.fuse_convs_dict = nn.ModuleDict()
210
+ for f_size in self.connect_list:
211
+ in_ch = self.channels[f_size]
212
+ self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
213
+
214
+ def _init_weights(self, module):
215
+ if isinstance(module, (nn.Linear, nn.Embedding)):
216
+ module.weight.data.normal_(mean=0.0, std=0.02)
217
+ if isinstance(module, nn.Linear) and module.bias is not None:
218
+ module.bias.data.zero_()
219
+ elif isinstance(module, nn.LayerNorm):
220
+ module.bias.data.zero_()
221
+ module.weight.data.fill_(1.0)
222
+
223
+ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
224
+ # ################### Encoder #####################
225
+ enc_feat_dict = {}
226
+ out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
227
+ for i, block in enumerate(self.encoder.blocks):
228
+ x = block(x)
229
+ if i in out_list:
230
+ enc_feat_dict[str(x.shape[-1])] = x.clone()
231
+
232
+ lq_feat = x
233
+ # ################# Transformer ###################
234
+ # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
235
+ pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
236
+ # BCHW -> BC(HW) -> (HW)BC
237
+ feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
238
+ query_emb = feat_emb
239
+ # Transformer encoder
240
+ for layer in self.ft_layers:
241
+ query_emb = layer(query_emb, query_pos=pos_emb)
242
+
243
+ # output logits
244
+ logits = self.idx_pred_layer(query_emb) # (hw)bn
245
+ logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
246
+
247
+ if code_only: # for training stage II
248
+ # logits doesn't need softmax before cross_entropy loss
249
+ return logits, lq_feat
250
+
251
+ # ################# Quantization ###################
252
+ # if self.training:
253
+ # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
254
+ # # b(hw)c -> bc(hw) -> bchw
255
+ # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
256
+ # ------------
257
+ soft_one_hot = F.softmax(logits, dim=2)
258
+ _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
259
+ quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
260
+ # preserve gradients
261
+ # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
262
+
263
+ if detach_16:
264
+ quant_feat = quant_feat.detach() # for training stage III
265
+ if adain:
266
+ quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
267
+
268
+ # ################## Generator ####################
269
+ x = quant_feat
270
+ fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
271
+
272
+ for i, block in enumerate(self.generator.blocks):
273
+ x = block(x)
274
+ if i in fuse_list: # fuse after i-th block
275
+ f_size = str(x.shape[-1])
276
+ if w>0:
277
+ x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
278
+ out = x
279
+ # logits doesn't need softmax before cross_entropy loss
280
+ return out, logits, lq_feat
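A minimal forward-pass sketch under the default constructor arguments. The 512x512 input size is an assumption based on the hard-coded 16x16 latent grid (latent_size=256), and real use would load the pretrained CodeFormer weights (vqgan_path or a full checkpoint) rather than run randomly initialized modules.

    import torch
    from codeformer.basicsr.archs.codeformer_arch import CodeFormer

    net = CodeFormer().eval()              # dim_embd=512, n_head=8, n_layers=9, codebook_size=1024
    x = torch.randn(1, 3, 512, 512)        # an aligned/cropped face image
    with torch.no_grad():
        out, logits, lq_feat = net(x, w=0.5, adain=True)
    print(out.shape, logits.shape)         # expected: (1, 3, 512, 512) and (1, 256, 1024)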
blissful_tuner/codeformer/basicsr/archs/rrdbnet_arch.py ADDED
@@ -0,0 +1,119 @@
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ from codeformer.basicsr.utils.registry import ARCH_REGISTRY
6
+ from .arch_util import default_init_weights, make_layer, pixel_unshuffle
7
+
8
+
9
+ class ResidualDenseBlock(nn.Module):
10
+ """Residual Dense Block.
11
+
12
+ Used in RRDB block in ESRGAN.
13
+
14
+ Args:
15
+ num_feat (int): Channel number of intermediate features.
16
+ num_grow_ch (int): Channels for each growth.
17
+ """
18
+
19
+ def __init__(self, num_feat=64, num_grow_ch=32):
20
+ super(ResidualDenseBlock, self).__init__()
21
+ self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
22
+ self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
23
+ self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
24
+ self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
25
+ self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
26
+
27
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
28
+
29
+ # initialization
30
+ default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
31
+
32
+ def forward(self, x):
33
+ x1 = self.lrelu(self.conv1(x))
34
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
35
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
36
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
37
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
38
+ # Empirically, we use 0.2 to scale the residual for better performance
39
+ return x5 * 0.2 + x
40
+
41
+
42
+ class RRDB(nn.Module):
43
+ """Residual in Residual Dense Block.
44
+
45
+ Used in RRDB-Net in ESRGAN.
46
+
47
+ Args:
48
+ num_feat (int): Channel number of intermediate features.
49
+ num_grow_ch (int): Channels for each growth.
50
+ """
51
+
52
+ def __init__(self, num_feat, num_grow_ch=32):
53
+ super(RRDB, self).__init__()
54
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
55
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
56
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
57
+
58
+ def forward(self, x):
59
+ out = self.rdb1(x)
60
+ out = self.rdb2(out)
61
+ out = self.rdb3(out)
62
+ # Empirically, we use 0.2 to scale the residual for better performance
63
+ return out * 0.2 + x
64
+
65
+
66
+ @ARCH_REGISTRY.register()
67
+ class RRDBNet(nn.Module):
68
+ """Networks consisting of Residual in Residual Dense Block, which is used
69
+ in ESRGAN.
70
+
71
+ ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
72
+
73
+ We extend ESRGAN for scale x2 and scale x1.
74
+ Note: This is one option for scale 1, scale 2 in RRDBNet.
75
+ We first employ pixel-unshuffle (an inverse operation of pixel-shuffle) to reduce the spatial size
76
+ and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
77
+
78
+ Args:
79
+ num_in_ch (int): Channel number of inputs.
80
+ num_out_ch (int): Channel number of outputs.
81
+ num_feat (int): Channel number of intermediate features.
82
+ Default: 64
83
+ num_block (int): Block number in the trunk network. Defaults: 23
84
+ num_grow_ch (int): Channels for each growth. Default: 32.
85
+ """
86
+
87
+ def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
88
+ super(RRDBNet, self).__init__()
89
+ self.scale = scale
90
+ if scale == 2:
91
+ num_in_ch = num_in_ch * 4
92
+ elif scale == 1:
93
+ num_in_ch = num_in_ch * 16
94
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
95
+ self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
96
+ self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
97
+ # upsample
98
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
99
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
100
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
101
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
102
+
103
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
104
+
105
+ def forward(self, x):
106
+ if self.scale == 2:
107
+ feat = pixel_unshuffle(x, scale=2)
108
+ elif self.scale == 1:
109
+ feat = pixel_unshuffle(x, scale=4)
110
+ else:
111
+ feat = x
112
+ feat = self.conv_first(feat)
113
+ body_feat = self.conv_body(self.body(feat))
114
+ feat = feat + body_feat
115
+ # upsample
116
+ feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
117
+ feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
118
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
119
+ return out
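A quick shape check (sizes chosen here): the trunk always upsamples by 4x after the optional pixel-unshuffle, so scale=4 enlarges the input 4x, scale=2 enlarges it 2x, and scale=1 returns the original resolution.

    import torch
    from codeformer.basicsr.archs.rrdbnet_arch import RRDBNet

    net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4, num_feat=64, num_block=23).eval()
    x = torch.randn(1, 3, 64, 64)
    with torch.no_grad():
        y = net(x)
    print(y.shape)   # torch.Size([1, 3, 256, 256]): 4x super-resolution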
blissful_tuner/codeformer/basicsr/archs/vgg_arch.py ADDED
@@ -0,0 +1,161 @@
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from torch import nn as nn
5
+ from torchvision.models import vgg as vgg
6
+
7
+ from codeformer.basicsr.utils.registry import ARCH_REGISTRY
8
+
9
+ VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
10
+ NAMES = {
11
+ 'vgg11': [
12
+ 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
13
+ 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
14
+ 'pool5'
15
+ ],
16
+ 'vgg13': [
17
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
18
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
19
+ 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
20
+ ],
21
+ 'vgg16': [
22
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
23
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
24
+ 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
25
+ 'pool5'
26
+ ],
27
+ 'vgg19': [
28
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
29
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
30
+ 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
31
+ 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
32
+ ]
33
+ }
34
+
35
+
36
+ def insert_bn(names):
37
+ """Insert bn layer after each conv.
38
+
39
+ Args:
40
+ names (list): The list of layer names.
41
+
42
+ Returns:
43
+ list: The list of layer names with bn layers.
44
+ """
45
+ names_bn = []
46
+ for name in names:
47
+ names_bn.append(name)
48
+ if 'conv' in name:
49
+ position = name.replace('conv', '')
50
+ names_bn.append('bn' + position)
51
+ return names_bn
52
+
53
+
54
+ @ARCH_REGISTRY.register()
55
+ class VGGFeatureExtractor(nn.Module):
56
+ """VGG network for feature extraction.
57
+
58
+ In this implementation, we allow users to choose whether use normalization
59
+ in the input feature and the type of vgg network. Note that the pretrained
60
+ path must fit the vgg type.
61
+
62
+ Args:
63
+ layer_name_list (list[str]): Forward function returns the corresponding
64
+ features according to the layer_name_list.
65
+ Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
66
+ vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
67
+ use_input_norm (bool): If True, normalize the input image. Importantly,
68
+ the input feature must be in the range [0, 1]. Default: True.
69
+ range_norm (bool): If True, normalize images from the range [-1, 1] to [0, 1].
70
+ Default: False.
71
+ requires_grad (bool): If true, the parameters of VGG network will be
72
+ optimized. Default: False.
73
+ remove_pooling (bool): If true, the max pooling operations in VGG net
74
+ will be removed. Default: False.
75
+ pooling_stride (int): The stride of max pooling operation. Default: 2.
76
+ """
77
+
78
+ def __init__(self,
79
+ layer_name_list,
80
+ vgg_type='vgg19',
81
+ use_input_norm=True,
82
+ range_norm=False,
83
+ requires_grad=False,
84
+ remove_pooling=False,
85
+ pooling_stride=2):
86
+ super(VGGFeatureExtractor, self).__init__()
87
+
88
+ self.layer_name_list = layer_name_list
89
+ self.use_input_norm = use_input_norm
90
+ self.range_norm = range_norm
91
+
92
+ self.names = NAMES[vgg_type.replace('_bn', '')]
93
+ if 'bn' in vgg_type:
94
+ self.names = insert_bn(self.names)
95
+
96
+ # only borrow layers that will be used to avoid unused params
97
+ max_idx = 0
98
+ for v in layer_name_list:
99
+ idx = self.names.index(v)
100
+ if idx > max_idx:
101
+ max_idx = idx
102
+
103
+ if os.path.exists(VGG_PRETRAIN_PATH):
104
+ vgg_net = getattr(vgg, vgg_type)(pretrained=False)
105
+ state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
106
+ vgg_net.load_state_dict(state_dict)
107
+ else:
108
+ vgg_net = getattr(vgg, vgg_type)(pretrained=True)
109
+
110
+ features = vgg_net.features[:max_idx + 1]
111
+
112
+ modified_net = OrderedDict()
113
+ for k, v in zip(self.names, features):
114
+ if 'pool' in k:
115
+ # if remove_pooling is true, pooling operation will be removed
116
+ if remove_pooling:
117
+ continue
118
+ else:
119
+ # in some cases, we may want to change the default stride
120
+ modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
121
+ else:
122
+ modified_net[k] = v
123
+
124
+ self.vgg_net = nn.Sequential(modified_net)
125
+
126
+ if not requires_grad:
127
+ self.vgg_net.eval()
128
+ for param in self.parameters():
129
+ param.requires_grad = False
130
+ else:
131
+ self.vgg_net.train()
132
+ for param in self.parameters():
133
+ param.requires_grad = True
134
+
135
+ if self.use_input_norm:
136
+ # the mean is for image with range [0, 1]
137
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
138
+ # the std is for image with range [0, 1]
139
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
140
+
141
+ def forward(self, x):
142
+ """Forward function.
143
+
144
+ Args:
145
+ x (Tensor): Input tensor with shape (n, c, h, w).
146
+
147
+ Returns:
148
+ Tensor: Forward results.
149
+ """
150
+ if self.range_norm:
151
+ x = (x + 1) / 2
152
+ if self.use_input_norm:
153
+ x = (x - self.mean) / self.std
154
+ output = {}
155
+
156
+ for key, layer in self.vgg_net._modules.items():
157
+ x = layer(x)
158
+ if key in self.layer_name_list:
159
+ output[key] = x.clone()
160
+
161
+ return output
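A short sketch of extracting perceptual features: the constructor loads the torchvision VGG19 weights (from VGG_PRETRAIN_PATH if present, otherwise downloading them), and the input is expected in [0, 1] when use_input_norm is True.

    import torch
    from codeformer.basicsr.archs.vgg_arch import VGGFeatureExtractor

    extractor = VGGFeatureExtractor(layer_name_list=['relu1_1', 'relu2_1', 'relu3_1'], vgg_type='vgg19')
    x = torch.rand(1, 3, 224, 224)   # RGB image in [0, 1]
    with torch.no_grad():
        feats = extractor(x)
    print({k: v.shape for k, v in feats.items()})
    # expected: relu1_1 (1, 64, 224, 224), relu2_1 (1, 128, 112, 112), relu3_1 (1, 256, 56, 56)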
blissful_tuner/codeformer/basicsr/archs/vqgan_arch.py ADDED
@@ -0,0 +1,434 @@
1
+ '''
2
+ VQGAN code, adapted from the original created by the Unleashing Transformers authors:
3
+ https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
4
+
5
+ '''
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import copy
11
+ from codeformer.basicsr.utils import get_root_logger
12
+ from codeformer.basicsr.utils.registry import ARCH_REGISTRY
13
+
14
+ def normalize(in_channels):
15
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
16
+
17
+
18
+ @torch.jit.script
19
+ def swish(x):
20
+ return x*torch.sigmoid(x)
21
+
22
+
23
+ # Define VQVAE classes
24
+ class VectorQuantizer(nn.Module):
25
+ def __init__(self, codebook_size, emb_dim, beta):
26
+ super(VectorQuantizer, self).__init__()
27
+ self.codebook_size = codebook_size # number of embeddings
28
+ self.emb_dim = emb_dim # dimension of embedding
29
+ self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
30
+ self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
31
+ self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
32
+
33
+ def forward(self, z):
34
+ # reshape z -> (batch, height, width, channel) and flatten
35
+ z = z.permute(0, 2, 3, 1).contiguous()
36
+ z_flattened = z.view(-1, self.emb_dim)
37
+
38
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
39
+ d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
40
+ 2 * torch.matmul(z_flattened, self.embedding.weight.t())
41
+
42
+ mean_distance = torch.mean(d)
43
+ # find closest encodings
44
+ min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
45
+ # min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
46
+ # [0-1], higher score, higher confidence
47
+ # min_encoding_scores = torch.exp(-min_encoding_scores/10)
48
+
49
+ min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
50
+ min_encodings.scatter_(1, min_encoding_indices, 1)
51
+
52
+ # get quantized latent vectors
53
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
54
+ # compute loss for embedding
55
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
56
+ # preserve gradients
57
+ z_q = z + (z_q - z).detach()
58
+
59
+ # perplexity
60
+ e_mean = torch.mean(min_encodings, dim=0)
61
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
62
+ # reshape back to match original input shape
63
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
64
+
65
+ return z_q, loss, {
66
+ "perplexity": perplexity,
67
+ "min_encodings": min_encodings,
68
+ "min_encoding_indices": min_encoding_indices,
69
+ "mean_distance": mean_distance
70
+ }
71
+
72
+ def get_codebook_feat(self, indices, shape):
73
+ # input indices: batch*token_num -> (batch*token_num)*1
74
+ # shape: batch, height, width, channel
75
+ indices = indices.view(-1,1)
76
+ min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
77
+ min_encodings.scatter_(1, indices, 1)
78
+ # get quantized latent vectors
79
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
80
+
81
+ if shape is not None: # reshape back to match original input shape
82
+ z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
83
+
84
+ return z_q
85
+
86
+
87
+ class GumbelQuantizer(nn.Module):
88
+ def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
89
+ super().__init__()
90
+ self.codebook_size = codebook_size # number of embeddings
91
+ self.emb_dim = emb_dim # dimension of embedding
92
+ self.straight_through = straight_through
93
+ self.temperature = temp_init
94
+ self.kl_weight = kl_weight
95
+ self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
96
+ self.embed = nn.Embedding(codebook_size, emb_dim)
97
+
98
+ def forward(self, z):
99
+ hard = self.straight_through if self.training else True
100
+
101
+ logits = self.proj(z)
102
+
103
+ soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
104
+
105
+ z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
106
+
107
+ # + kl divergence to the prior loss
108
+ qy = F.softmax(logits, dim=1)
109
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
110
+ min_encoding_indices = soft_one_hot.argmax(dim=1)
111
+
112
+ return z_q, diff, {
113
+ "min_encoding_indices": min_encoding_indices
114
+ }
115
+
116
+
117
+ class Downsample(nn.Module):
118
+ def __init__(self, in_channels):
119
+ super().__init__()
120
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
121
+
122
+ def forward(self, x):
123
+ pad = (0, 1, 0, 1)
124
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
125
+ x = self.conv(x)
126
+ return x
127
+
128
+
129
+ class Upsample(nn.Module):
130
+ def __init__(self, in_channels):
131
+ super().__init__()
132
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
133
+
134
+ def forward(self, x):
135
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
136
+ x = self.conv(x)
137
+
138
+ return x
139
+
140
+
141
+ class ResBlock(nn.Module):
142
+ def __init__(self, in_channels, out_channels=None):
143
+ super(ResBlock, self).__init__()
144
+ self.in_channels = in_channels
145
+ self.out_channels = in_channels if out_channels is None else out_channels
146
+ self.norm1 = normalize(in_channels)
147
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
148
+ self.norm2 = normalize(out_channels)
149
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
150
+ if self.in_channels != self.out_channels:
151
+ self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
152
+
153
+ def forward(self, x_in):
154
+ x = x_in
155
+ x = self.norm1(x)
156
+ x = swish(x)
157
+ x = self.conv1(x)
158
+ x = self.norm2(x)
159
+ x = swish(x)
160
+ x = self.conv2(x)
161
+ if self.in_channels != self.out_channels:
162
+ x_in = self.conv_out(x_in)
163
+
164
+ return x + x_in
165
+
166
+
167
+ class AttnBlock(nn.Module):
168
+ def __init__(self, in_channels):
169
+ super().__init__()
170
+ self.in_channels = in_channels
171
+
172
+ self.norm = normalize(in_channels)
173
+ self.q = torch.nn.Conv2d(
174
+ in_channels,
175
+ in_channels,
176
+ kernel_size=1,
177
+ stride=1,
178
+ padding=0
179
+ )
180
+ self.k = torch.nn.Conv2d(
181
+ in_channels,
182
+ in_channels,
183
+ kernel_size=1,
184
+ stride=1,
185
+ padding=0
186
+ )
187
+ self.v = torch.nn.Conv2d(
188
+ in_channels,
189
+ in_channels,
190
+ kernel_size=1,
191
+ stride=1,
192
+ padding=0
193
+ )
194
+ self.proj_out = torch.nn.Conv2d(
195
+ in_channels,
196
+ in_channels,
197
+ kernel_size=1,
198
+ stride=1,
199
+ padding=0
200
+ )
201
+
202
+ def forward(self, x):
203
+ h_ = x
204
+ h_ = self.norm(h_)
205
+ q = self.q(h_)
206
+ k = self.k(h_)
207
+ v = self.v(h_)
208
+
209
+ # compute attention
210
+ b, c, h, w = q.shape
211
+ q = q.reshape(b, c, h*w)
212
+ q = q.permute(0, 2, 1)
213
+ k = k.reshape(b, c, h*w)
214
+ w_ = torch.bmm(q, k)
215
+ w_ = w_ * (int(c)**(-0.5))
216
+ w_ = F.softmax(w_, dim=2)
217
+
218
+ # attend to values
219
+ v = v.reshape(b, c, h*w)
220
+ w_ = w_.permute(0, 2, 1)
221
+ h_ = torch.bmm(v, w_)
222
+ h_ = h_.reshape(b, c, h, w)
223
+
224
+ h_ = self.proj_out(h_)
225
+
226
+ return x+h_
227
+
228
+
229
+ class Encoder(nn.Module):
230
+ def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
231
+ super().__init__()
232
+ self.nf = nf
233
+ self.num_resolutions = len(ch_mult)
234
+ self.num_res_blocks = num_res_blocks
235
+ self.resolution = resolution
236
+ self.attn_resolutions = attn_resolutions
237
+
238
+ curr_res = self.resolution
239
+ in_ch_mult = (1,)+tuple(ch_mult)
240
+
241
+ blocks = []
242
+ # initial convolution
243
+ blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
244
+
245
+ # residual and downsampling blocks, with attention on smaller res (16x16)
246
+ for i in range(self.num_resolutions):
247
+ block_in_ch = nf * in_ch_mult[i]
248
+ block_out_ch = nf * ch_mult[i]
249
+ for _ in range(self.num_res_blocks):
250
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
251
+ block_in_ch = block_out_ch
252
+ if curr_res in attn_resolutions:
253
+ blocks.append(AttnBlock(block_in_ch))
254
+
255
+ if i != self.num_resolutions - 1:
256
+ blocks.append(Downsample(block_in_ch))
257
+ curr_res = curr_res // 2
258
+
259
+ # non-local attention block
260
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
261
+ blocks.append(AttnBlock(block_in_ch))
262
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
263
+
264
+ # normalise and convert to latent size
265
+ blocks.append(normalize(block_in_ch))
266
+ blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
267
+ self.blocks = nn.ModuleList(blocks)
268
+
269
+ def forward(self, x):
270
+ for block in self.blocks:
271
+ x = block(x)
272
+
273
+ return x
274
+
275
+
276
+ class Generator(nn.Module):
277
+ def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
278
+ super().__init__()
279
+ self.nf = nf
280
+ self.ch_mult = ch_mult
281
+ self.num_resolutions = len(self.ch_mult)
282
+ self.num_res_blocks = res_blocks
283
+ self.resolution = img_size
284
+ self.attn_resolutions = attn_resolutions
285
+ self.in_channels = emb_dim
286
+ self.out_channels = 3
287
+ block_in_ch = self.nf * self.ch_mult[-1]
288
+ curr_res = self.resolution // 2 ** (self.num_resolutions-1)
289
+
290
+ blocks = []
291
+ # initial conv
292
+ blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
293
+
294
+ # non-local attention block
295
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
296
+ blocks.append(AttnBlock(block_in_ch))
297
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
298
+
299
+ for i in reversed(range(self.num_resolutions)):
300
+ block_out_ch = self.nf * self.ch_mult[i]
301
+
302
+ for _ in range(self.num_res_blocks):
303
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
304
+ block_in_ch = block_out_ch
305
+
306
+ if curr_res in self.attn_resolutions:
307
+ blocks.append(AttnBlock(block_in_ch))
308
+
309
+ if i != 0:
310
+ blocks.append(Upsample(block_in_ch))
311
+ curr_res = curr_res * 2
312
+
313
+ blocks.append(normalize(block_in_ch))
314
+ blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
315
+
316
+ self.blocks = nn.ModuleList(blocks)
317
+
318
+
319
+ def forward(self, x):
320
+ for block in self.blocks:
321
+ x = block(x)
322
+
323
+ return x
324
+
325
+
326
+ @ARCH_REGISTRY.register()
327
+ class VQAutoEncoder(nn.Module):
328
+ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
329
+ beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
330
+ super().__init__()
331
+ logger = get_root_logger()
332
+ self.in_channels = 3
333
+ self.nf = nf
334
+ self.n_blocks = res_blocks
335
+ self.codebook_size = codebook_size
336
+ self.embed_dim = emb_dim
337
+ self.ch_mult = ch_mult
338
+ self.resolution = img_size
339
+ self.attn_resolutions = attn_resolutions
340
+ self.quantizer_type = quantizer
341
+ self.encoder = Encoder(
342
+ self.in_channels,
343
+ self.nf,
344
+ self.embed_dim,
345
+ self.ch_mult,
346
+ self.n_blocks,
347
+ self.resolution,
348
+ self.attn_resolutions
349
+ )
350
+ if self.quantizer_type == "nearest":
351
+ self.beta = beta #0.25
352
+ self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
353
+ elif self.quantizer_type == "gumbel":
354
+ self.gumbel_num_hiddens = emb_dim
355
+ self.straight_through = gumbel_straight_through
356
+ self.kl_weight = gumbel_kl_weight
357
+ self.quantize = GumbelQuantizer(
358
+ self.codebook_size,
359
+ self.embed_dim,
360
+ self.gumbel_num_hiddens,
361
+ self.straight_through,
362
+ self.kl_weight
363
+ )
364
+ self.generator = Generator(
365
+ self.nf,
366
+ self.embed_dim,
367
+ self.ch_mult,
368
+ self.n_blocks,
369
+ self.resolution,
370
+ self.attn_resolutions
371
+ )
372
+
373
+ if model_path is not None:
374
+ chkpt = torch.load(model_path, map_location='cpu')
375
+ if 'params_ema' in chkpt:
376
+ self.load_state_dict(chkpt['params_ema'])
377
+ logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
378
+ elif 'params' in chkpt:
379
+ self.load_state_dict(chkpt['params'])
380
+ logger.info(f'vqgan is loaded from: {model_path} [params]')
381
+ else:
382
+ raise ValueError(f'Unexpected checkpoint format: {model_path} has neither "params_ema" nor "params".')
383
+
384
+
385
+ def forward(self, x):
386
+ x = self.encoder(x)
387
+ quant, codebook_loss, quant_stats = self.quantize(x)
388
+ x = self.generator(quant)
389
+ return x, codebook_loss, quant_stats
390
+
391
+
392
+
393
+ # patch based discriminator
394
+ @ARCH_REGISTRY.register()
395
+ class VQGANDiscriminator(nn.Module):
396
+ def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
397
+ super().__init__()
398
+
399
+ layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
400
+ ndf_mult = 1
401
+ ndf_mult_prev = 1
402
+ for n in range(1, n_layers): # gradually increase the number of filters
403
+ ndf_mult_prev = ndf_mult
404
+ ndf_mult = min(2 ** n, 8)
405
+ layers += [
406
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
407
+ nn.BatchNorm2d(ndf * ndf_mult),
408
+ nn.LeakyReLU(0.2, True)
409
+ ]
410
+
411
+ ndf_mult_prev = ndf_mult
412
+ ndf_mult = min(2 ** n_layers, 8)
413
+
414
+ layers += [
415
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
416
+ nn.BatchNorm2d(ndf * ndf_mult),
417
+ nn.LeakyReLU(0.2, True)
418
+ ]
419
+
420
+ layers += [
421
+ nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
422
+ self.main = nn.Sequential(*layers)
423
+
424
+ if model_path is not None:
425
+ chkpt = torch.load(model_path, map_location='cpu')
426
+ if 'params_d' in chkpt:
427
+ self.load_state_dict(chkpt['params_d'])
428
+ elif 'params' in chkpt:
429
+ self.load_state_dict(chkpt['params'])
430
+ else:
431
+ raise ValueError(f'Unexpected checkpoint format: {model_path} has neither "params_d" nor "params".')
432
+
433
+ def forward(self, x):
434
+ return self.main(x)
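A short round-trip sketch for the autoencoder above (illustrative only, not part of the uploaded files). The 512px hyper-parameters below are assumptions chosen to exercise the code, not a training recipe, and no pretrained weights are loaded.
import torch
from codeformer.basicsr.archs.vqgan_arch import VQAutoEncoder

# Encode -> quantize (nearest-neighbour codebook) -> decode a dummy image.
vqgan = VQAutoEncoder(img_size=512, nf=64, ch_mult=[1, 2, 2, 4, 4, 8],
                      quantizer='nearest', codebook_size=1024, emb_dim=256).eval()

x = torch.rand(1, 3, 512, 512)
with torch.no_grad():
    recon, codebook_loss, stats = vqgan(x)
print(recon.shape, codebook_loss.item(), stats['min_encoding_indices'].shape)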
blissful_tuner/codeformer/basicsr/data/__init__.py ADDED
@@ -0,0 +1,100 @@
1
+ import importlib
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ import torch.utils.data
6
+ from copy import deepcopy
7
+ from functools import partial
8
+ from os import path as osp
9
+
10
+ from codeformer.basicsr.data.prefetch_dataloader import PrefetchDataLoader
11
+ from codeformer.basicsr.utils import get_root_logger, scandir
12
+ from codeformer.basicsr.utils.dist_util import get_dist_info
13
+ from codeformer.basicsr.utils.registry import DATASET_REGISTRY
14
+
15
+ __all__ = ['build_dataset', 'build_dataloader']
16
+
17
+ # automatically scan and import dataset modules for registry
18
+ # scan all the files under the data folder with '_dataset' in file names
19
+ data_folder = osp.dirname(osp.abspath(__file__))
20
+ dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
21
+ # import all the dataset modules
22
+ _dataset_modules = [importlib.import_module(f'codeformer.basicsr.data.{file_name}') for file_name in dataset_filenames]
23
+
24
+
25
+ def build_dataset(dataset_opt):
26
+ """Build dataset from options.
27
+
28
+ Args:
29
+ dataset_opt (dict): Configuration for dataset. It must contain:
30
+ name (str): Dataset name.
31
+ type (str): Dataset type.
32
+ """
33
+ dataset_opt = deepcopy(dataset_opt)
34
+ dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
35
+ logger = get_root_logger()
36
+ logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.')
37
+ return dataset
38
+
39
+
40
+ def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
41
+ """Build dataloader.
42
+
43
+ Args:
44
+ dataset (torch.utils.data.Dataset): Dataset.
45
+ dataset_opt (dict): Dataset options. It contains the following keys:
46
+ phase (str): 'train' or 'val'.
47
+ num_worker_per_gpu (int): Number of workers for each GPU.
48
+ batch_size_per_gpu (int): Training batch size for each GPU.
49
+ num_gpu (int): Number of GPUs. Used only in the train phase.
50
+ Default: 1.
51
+ dist (bool): Whether in distributed training. Used only in the train
52
+ phase. Default: False.
53
+ sampler (torch.utils.data.sampler): Data sampler. Default: None.
54
+ seed (int | None): Seed. Default: None
55
+ """
56
+ phase = dataset_opt['phase']
57
+ rank, _ = get_dist_info()
58
+ if phase == 'train':
59
+ if dist: # distributed training
60
+ batch_size = dataset_opt['batch_size_per_gpu']
61
+ num_workers = dataset_opt['num_worker_per_gpu']
62
+ else: # non-distributed training
63
+ multiplier = 1 if num_gpu == 0 else num_gpu
64
+ batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
65
+ num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
66
+ dataloader_args = dict(
67
+ dataset=dataset,
68
+ batch_size=batch_size,
69
+ shuffle=False,
70
+ num_workers=num_workers,
71
+ sampler=sampler,
72
+ drop_last=True)
73
+ if sampler is None:
74
+ dataloader_args['shuffle'] = True
75
+ dataloader_args['worker_init_fn'] = partial(
76
+ worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
77
+ elif phase in ['val', 'test']: # validation
78
+ dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
79
+ else:
80
+ raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")
81
+
82
+ dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
83
+
84
+ prefetch_mode = dataset_opt.get('prefetch_mode')
85
+ if prefetch_mode == 'cpu': # CPUPrefetcher
86
+ num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
87
+ logger = get_root_logger()
88
+ logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}')
89
+ return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
90
+ else:
91
+ # prefetch_mode=None: Normal dataloader
92
+ # prefetch_mode='cuda': dataloader for CUDAPrefetcher
93
+ return torch.utils.data.DataLoader(**dataloader_args)
94
+
95
+
96
+ def worker_init_fn(worker_id, num_workers, rank, seed):
97
+ # Set the worker seed to num_workers * rank + worker_id + seed
98
+ worker_seed = num_workers * rank + worker_id + seed
99
+ np.random.seed(worker_seed)
100
+ random.seed(worker_seed)
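A sketch of driving `build_dataset`/`build_dataloader` from an options dict (illustrative only, not part of the uploaded files). The keys mirror what the functions above and `FFHQBlindDataset` read; the dataset path is a placeholder, and `use_corrupt` is disabled so no degradation options are needed for this demo.
from codeformer.basicsr.data import build_dataset, build_dataloader

dataset_opt = {
    'name': 'demo',
    'type': 'FFHQBlindDataset',              # registered in DATASET_REGISTRY on import
    'phase': 'train',
    'dataroot_gt': '/path/to/aligned_faces', # placeholder folder of images
    'io_backend': {'type': 'disk'},
    'use_hflip': True,
    'use_corrupt': False,                    # skip the degradation pipeline for this demo
    'num_worker_per_gpu': 2,
    'batch_size_per_gpu': 4,
}

dataset = build_dataset(dataset_opt)
loader = build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=0)
batch = next(iter(loader))
print(batch['in'].shape, batch['gt'].shape)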
blissful_tuner/codeformer/basicsr/data/data_sampler.py ADDED
@@ -0,0 +1,48 @@
1
+ import math
2
+ import torch
3
+ from torch.utils.data.sampler import Sampler
4
+
5
+
6
+ class EnlargedSampler(Sampler):
7
+ """Sampler that restricts data loading to a subset of the dataset.
8
+
9
+ Modified from torch.utils.data.distributed.DistributedSampler
10
+ Support enlarging the dataset for iteration-based training, for saving
11
+ time when restarting the dataloader after each epoch.
12
+
13
+ Args:
14
+ dataset (torch.utils.data.Dataset): Dataset used for sampling.
15
+ num_replicas (int | None): Number of processes participating in
16
+ the training. It is usually the world_size.
17
+ rank (int | None): Rank of the current process within num_replicas.
18
+ ratio (int): Enlarging ratio. Default: 1.
19
+ """
20
+
21
+ def __init__(self, dataset, num_replicas, rank, ratio=1):
22
+ self.dataset = dataset
23
+ self.num_replicas = num_replicas
24
+ self.rank = rank
25
+ self.epoch = 0
26
+ self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
27
+ self.total_size = self.num_samples * self.num_replicas
28
+
29
+ def __iter__(self):
30
+ # deterministically shuffle based on epoch
31
+ g = torch.Generator()
32
+ g.manual_seed(self.epoch)
33
+ indices = torch.randperm(self.total_size, generator=g).tolist()
34
+
35
+ dataset_size = len(self.dataset)
36
+ indices = [v % dataset_size for v in indices]
37
+
38
+ # subsample
39
+ indices = indices[self.rank:self.total_size:self.num_replicas]
40
+ assert len(indices) == self.num_samples
41
+
42
+ return iter(indices)
43
+
44
+ def __len__(self):
45
+ return self.num_samples
46
+
47
+ def set_epoch(self, epoch):
48
+ self.epoch = epoch
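A small sketch of the sampler above in a single-process setting (illustrative only, not part of the uploaded files); `ratio=3` enlarges a 10-item dataset to 30 indices per epoch, and `set_epoch` reseeds the shuffle each epoch.
import torch
from torch.utils.data import DataLoader, TensorDataset
from codeformer.basicsr.data.data_sampler import EnlargedSampler

dataset = TensorDataset(torch.arange(10))
sampler = EnlargedSampler(dataset, num_replicas=1, rank=0, ratio=3)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)       # deterministic, epoch-dependent shuffling
    for (batch,) in loader:
        pass                       # training step would go here
print(len(sampler))                # 30 indices per epoch with ratio=3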
blissful_tuner/codeformer/basicsr/data/data_util.py ADDED
@@ -0,0 +1,392 @@
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from os import path as osp
6
+ from PIL import Image, ImageDraw
7
+ from torch.nn import functional as F
8
+
9
+ from codeformer.basicsr.data.transforms import mod_crop
10
+ from codeformer.basicsr.utils import img2tensor, scandir
11
+
12
+
13
+ def read_img_seq(path, require_mod_crop=False, scale=1):
14
+ """Read a sequence of images from a given folder path.
15
+
16
+ Args:
17
+ path (list[str] | str): List of image paths or image folder path.
18
+ require_mod_crop (bool): Require mod crop for each image.
19
+ Default: False.
20
+ scale (int): Scale factor for mod_crop. Default: 1.
21
+
22
+ Returns:
23
+ Tensor: size (t, c, h, w), RGB, [0, 1].
24
+ """
25
+ if isinstance(path, list):
26
+ img_paths = path
27
+ else:
28
+ img_paths = sorted(list(scandir(path, full_path=True)))
29
+ imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
30
+ if require_mod_crop:
31
+ imgs = [mod_crop(img, scale) for img in imgs]
32
+ imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
33
+ imgs = torch.stack(imgs, dim=0)
34
+ return imgs
35
+
36
+
37
+ def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
38
+ """Generate an index list for reading `num_frames` frames from a sequence
39
+ of images.
40
+
41
+ Args:
42
+ crt_idx (int): Current center index.
43
+ max_frame_num (int): Max number of the sequence of images (from 1).
44
+ num_frames (int): Reading num_frames frames.
45
+ padding (str): Padding mode, one of
46
+ 'replicate' | 'reflection' | 'reflection_circle' | 'circle'
47
+ Examples: current_idx = 0, num_frames = 5
48
+ The generated frame indices under different padding mode:
49
+ replicate: [0, 0, 0, 1, 2]
50
+ reflection: [2, 1, 0, 1, 2]
51
+ reflection_circle: [4, 3, 0, 1, 2]
52
+ circle: [3, 4, 0, 1, 2]
53
+
54
+ Returns:
55
+ list[int]: A list of indices.
56
+ """
57
+ assert num_frames % 2 == 1, 'num_frames should be an odd number.'
58
+ assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
59
+
60
+ max_frame_num = max_frame_num - 1 # start from 0
61
+ num_pad = num_frames // 2
62
+
63
+ indices = []
64
+ for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
65
+ if i < 0:
66
+ if padding == 'replicate':
67
+ pad_idx = 0
68
+ elif padding == 'reflection':
69
+ pad_idx = -i
70
+ elif padding == 'reflection_circle':
71
+ pad_idx = crt_idx + num_pad - i
72
+ else:
73
+ pad_idx = num_frames + i
74
+ elif i > max_frame_num:
75
+ if padding == 'replicate':
76
+ pad_idx = max_frame_num
77
+ elif padding == 'reflection':
78
+ pad_idx = max_frame_num * 2 - i
79
+ elif padding == 'reflection_circle':
80
+ pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
81
+ else:
82
+ pad_idx = i - num_frames
83
+ else:
84
+ pad_idx = i
85
+ indices.append(pad_idx)
86
+ return indices
87
+
88
+
89
+ def paired_paths_from_lmdb(folders, keys):
90
+ """Generate paired paths from lmdb files.
91
+
92
+ Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
93
+
94
+ lq.lmdb
95
+ ├── data.mdb
96
+ ├── lock.mdb
97
+ ├── meta_info.txt
98
+
99
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
100
+ https://lmdb.readthedocs.io/en/release/ for more details.
101
+
102
+ The meta_info.txt is a specified txt file to record the meta information
103
+ of our datasets. It will be automatically created when preparing
104
+ datasets by our provided dataset tools.
105
+ Each line in the txt file records
106
+ 1)image name (with extension),
107
+ 2)image shape,
108
+ 3)compression level, separated by a white space.
109
+ Example: `baboon.png (120,125,3) 1`
110
+
111
+ We use the image name without extension as the lmdb key.
112
+ Note that we use the same key for the corresponding lq and gt images.
113
+
114
+ Args:
115
+ folders (list[str]): A list of folder path. The order of list should
116
+ be [input_folder, gt_folder].
117
+ keys (list[str]): A list of keys identifying folders. The order should
118
+ be consistent with folders, e.g., ['lq', 'gt'].
119
+ Note that this key is different from lmdb keys.
120
+
121
+ Returns:
122
+ list[str]: Returned path list.
123
+ """
124
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
125
+ f'But got {len(folders)}')
126
+ assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
127
+ input_folder, gt_folder = folders
128
+ input_key, gt_key = keys
129
+
130
+ if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
131
+ raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
132
+ f'formats. But received {input_key}: {input_folder}; '
133
+ f'{gt_key}: {gt_folder}')
134
+ # ensure that the two meta_info files are the same
135
+ with open(osp.join(input_folder, 'meta_info.txt')) as fin:
136
+ input_lmdb_keys = [line.split('.')[0] for line in fin]
137
+ with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
138
+ gt_lmdb_keys = [line.split('.')[0] for line in fin]
139
+ if set(input_lmdb_keys) != set(gt_lmdb_keys):
140
+ raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
141
+ else:
142
+ paths = []
143
+ for lmdb_key in sorted(input_lmdb_keys):
144
+ paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
145
+ return paths
146
+
147
+
148
+ def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
149
+ """Generate paired paths from an meta information file.
150
+
151
+ Each line in the meta information file contains the image names and
152
+ image shape (usually for gt), separated by a white space.
153
+
154
+ Example of a meta information file:
155
+ ```
156
+ 0001_s001.png (480,480,3)
157
+ 0001_s002.png (480,480,3)
158
+ ```
159
+
160
+ Args:
161
+ folders (list[str]): A list of folder path. The order of list should
162
+ be [input_folder, gt_folder].
163
+ keys (list[str]): A list of keys identifying folders. The order should
164
+ be consistent with folders, e.g., ['lq', 'gt'].
165
+ meta_info_file (str): Path to the meta information file.
166
+ filename_tmpl (str): Template for each filename. Note that the
167
+ template excludes the file extension. Usually the filename_tmpl is
168
+ for files in the input folder.
169
+
170
+ Returns:
171
+ list[str]: Returned path list.
172
+ """
173
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
174
+ f'But got {len(folders)}')
175
+ assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
176
+ input_folder, gt_folder = folders
177
+ input_key, gt_key = keys
178
+
179
+ with open(meta_info_file, 'r') as fin:
180
+ gt_names = [line.split(' ')[0] for line in fin]
181
+
182
+ paths = []
183
+ for gt_name in gt_names:
184
+ basename, ext = osp.splitext(osp.basename(gt_name))
185
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
186
+ input_path = osp.join(input_folder, input_name)
187
+ gt_path = osp.join(gt_folder, gt_name)
188
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
189
+ return paths
190
+
191
+
192
+ def paired_paths_from_folder(folders, keys, filename_tmpl):
193
+ """Generate paired paths from folders.
194
+
195
+ Args:
196
+ folders (list[str]): A list of folder path. The order of list should
197
+ be [input_folder, gt_folder].
198
+ keys (list[str]): A list of keys identifying folders. The order should
199
+ be consistent with folders, e.g., ['lq', 'gt'].
200
+ filename_tmpl (str): Template for each filename. Note that the
201
+ template excludes the file extension. Usually the filename_tmpl is
202
+ for files in the input folder.
203
+
204
+ Returns:
205
+ list[str]: Returned path list.
206
+ """
207
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
208
+ f'But got {len(folders)}')
209
+ assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
210
+ input_folder, gt_folder = folders
211
+ input_key, gt_key = keys
212
+
213
+ input_paths = list(scandir(input_folder))
214
+ gt_paths = list(scandir(gt_folder))
215
+ assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
216
+ f'{len(input_paths)}, {len(gt_paths)}.')
217
+ paths = []
218
+ for gt_path in gt_paths:
219
+ basename, ext = osp.splitext(osp.basename(gt_path))
220
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
221
+ input_path = osp.join(input_folder, input_name)
222
+ assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.')
223
+ gt_path = osp.join(gt_folder, gt_path)
224
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
225
+ return paths
226
+
227
+
228
+ def paths_from_folder(folder):
229
+ """Generate paths from folder.
230
+
231
+ Args:
232
+ folder (str): Folder path.
233
+
234
+ Returns:
235
+ list[str]: Returned path list.
236
+ """
237
+
238
+ paths = list(scandir(folder))
239
+ paths = [osp.join(folder, path) for path in paths]
240
+ return paths
241
+
242
+
243
+ def paths_from_lmdb(folder):
244
+ """Generate paths from lmdb.
245
+
246
+ Args:
247
+ folder (str): Folder path.
248
+
249
+ Returns:
250
+ list[str]: Returned path list.
251
+ """
252
+ if not folder.endswith('.lmdb'):
253
+ raise ValueError(f'Folder {folder} should be in lmdb format.')
254
+ with open(osp.join(folder, 'meta_info.txt')) as fin:
255
+ paths = [line.split('.')[0] for line in fin]
256
+ return paths
257
+
258
+
259
+ def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
260
+ """Generate Gaussian kernel used in `duf_downsample`.
261
+
262
+ Args:
263
+ kernel_size (int): Kernel size. Default: 13.
264
+ sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
265
+
266
+ Returns:
267
+ np.array: The Gaussian kernel.
268
+ """
269
+ from scipy.ndimage import gaussian_filter
270
+ kernel = np.zeros((kernel_size, kernel_size))
271
+ # set element at the middle to one, a dirac delta
272
+ kernel[kernel_size // 2, kernel_size // 2] = 1
273
+ # gaussian-smooth the dirac, resulting in a gaussian filter
274
+ return gaussian_filter(kernel, sigma)
275
+
276
+
277
+ def duf_downsample(x, kernel_size=13, scale=4):
278
+ """Downsamping with Gaussian kernel used in the DUF official code.
279
+
280
+ Args:
281
+ x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
282
+ kernel_size (int): Kernel size. Default: 13.
283
+ scale (int): Downsampling factor. Supported scale: (2, 3, 4).
284
+ Default: 4.
285
+
286
+ Returns:
287
+ Tensor: DUF downsampled frames.
288
+ """
289
+ assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
290
+
291
+ squeeze_flag = False
292
+ if x.ndim == 4:
293
+ squeeze_flag = True
294
+ x = x.unsqueeze(0)
295
+ b, t, c, h, w = x.size()
296
+ x = x.view(-1, 1, h, w)
297
+ pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
298
+ x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
299
+
300
+ gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
301
+ gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
302
+ x = F.conv2d(x, gaussian_filter, stride=scale)
303
+ x = x[:, :, 2:-2, 2:-2]
304
+ x = x.view(b, t, c, x.size(2), x.size(3))
305
+ if squeeze_flag:
306
+ x = x.squeeze(0)
307
+ return x
308
+
309
+
310
+ def brush_stroke_mask(img, color=(255,255,255)):
311
+ min_num_vertex = 8
312
+ max_num_vertex = 28
313
+ mean_angle = 2*math.pi / 5
314
+ angle_range = 2*math.pi / 12
315
+ # training large mask ratio (training setting)
316
+ min_width = 30
317
+ max_width = 70
318
+ # very large mask ratio (test setting and refine after 200k)
319
+ # min_width = 80
320
+ # max_width = 120
321
+ def generate_mask(H, W, img=None):
322
+ average_radius = math.sqrt(H*H+W*W) / 8
323
+ mask = Image.new('RGB', (W, H), 0)
324
+ if img is not None: mask = img # Image.fromarray(img)
325
+
326
+ for _ in range(np.random.randint(1, 4)):
327
+ num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
328
+ angle_min = mean_angle - np.random.uniform(0, angle_range)
329
+ angle_max = mean_angle + np.random.uniform(0, angle_range)
330
+ angles = []
331
+ vertex = []
332
+ for i in range(num_vertex):
333
+ if i % 2 == 0:
334
+ angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
335
+ else:
336
+ angles.append(np.random.uniform(angle_min, angle_max))
337
+
338
+ h, w = mask.size
339
+ vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
340
+ for i in range(num_vertex):
341
+ r = np.clip(
342
+ np.random.normal(loc=average_radius, scale=average_radius//2),
343
+ 0, 2*average_radius)
344
+ new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
345
+ new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
346
+ vertex.append((int(new_x), int(new_y)))
347
+
348
+ draw = ImageDraw.Draw(mask)
349
+ width = int(np.random.uniform(min_width, max_width))
350
+ draw.line(vertex, fill=color, width=width)
351
+ for v in vertex:
352
+ draw.ellipse((v[0] - width//2,
353
+ v[1] - width//2,
354
+ v[0] + width//2,
355
+ v[1] + width//2),
356
+ fill=color)
357
+
358
+ return mask
359
+
360
+ width, height = img.size
361
+ mask = generate_mask(height, width, img)
362
+ return mask
363
+
364
+
365
+ def random_ff_mask(shape, max_angle = 10, max_len = 100, max_width = 70, times = 10):
366
+ """Generate a random free form mask with configuration.
367
+ Args:
368
+ shape (tuple): Mask shape as (height, width).
369
+ max_angle, max_len, max_width, times: Parameters controlling the random strokes.
370
+ Returns:
371
+ np.ndarray: Float32 mask of shape (height, width), 1.0 on the drawn strokes.
372
+ Link:
373
+ https://github.com/csqiangwen/DeepFillv2_Pytorch/blob/master/train_dataset.py
374
+ """
375
+ height = shape[0]
376
+ width = shape[1]
377
+ mask = np.zeros((height, width), np.float32)
378
+ times = np.random.randint(times-5, times)
379
+ for i in range(times):
380
+ start_x = np.random.randint(width)
381
+ start_y = np.random.randint(height)
382
+ for j in range(1 + np.random.randint(5)):
383
+ angle = 0.01 + np.random.randint(max_angle)
384
+ if i % 2 == 0:
385
+ angle = 2 * 3.1415926 - angle
386
+ length = 10 + np.random.randint(max_len-20, max_len)
387
+ brush_w = 5 + np.random.randint(max_width-30, max_width)
388
+ end_x = (start_x + length * np.sin(angle)).astype(np.int32)
389
+ end_y = (start_y + length * np.cos(angle)).astype(np.int32)
390
+ cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)
391
+ start_x, start_y = end_x, end_y
392
+ return mask.astype(np.float32)
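For reference, a quick sketch of `generate_frame_indices` at the clip boundaries (illustrative only, not part of the uploaded files); the expected outputs follow the padding rules documented above.
from codeformer.basicsr.data.data_util import generate_frame_indices

# 30-frame clip, 5-frame window centred on the first and last frames.
print(generate_frame_indices(0, max_frame_num=30, num_frames=5, padding='reflection'))
# -> [2, 1, 0, 1, 2]
print(generate_frame_indices(29, max_frame_num=30, num_frames=5, padding='replicate'))
# -> [27, 28, 29, 29, 29]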
blissful_tuner/codeformer/basicsr/data/ffhq_blind_dataset.py ADDED
@@ -0,0 +1,299 @@
1
+ import cv2
2
+ import math
3
+ import random
4
+ import numpy as np
5
+ import os.path as osp
6
+ from scipy.io import loadmat
7
+ from PIL import Image
8
+ import torch
9
+ import torch.utils.data as data
10
+ from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
11
+ adjust_hue, adjust_saturation, normalize)
12
+ from codeformer.basicsr.data import gaussian_kernels as gaussian_kernels
13
+ from codeformer.basicsr.data.transforms import augment
14
+ from codeformer.basicsr.data.data_util import paths_from_folder, brush_stroke_mask, random_ff_mask
15
+ from codeformer.basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
16
+ from codeformer.basicsr.utils.registry import DATASET_REGISTRY
17
+
18
+ @DATASET_REGISTRY.register()
19
+ class FFHQBlindDataset(data.Dataset):
20
+
21
+ def __init__(self, opt):
22
+ super(FFHQBlindDataset, self).__init__()
23
+ logger = get_root_logger()
24
+ self.opt = opt
25
+ # file client (io backend)
26
+ self.file_client = None
27
+ self.io_backend_opt = opt['io_backend']
28
+
29
+ self.gt_folder = opt['dataroot_gt']
30
+ self.gt_size = opt.get('gt_size', 512)
31
+ self.in_size = opt.get('in_size', 512)
32
+ assert self.gt_size >= self.in_size, 'Wrong setting.'
33
+
34
+ self.mean = opt.get('mean', [0.5, 0.5, 0.5])
35
+ self.std = opt.get('std', [0.5, 0.5, 0.5])
36
+
37
+ self.component_path = opt.get('component_path', None)
38
+ self.latent_gt_path = opt.get('latent_gt_path', None)
39
+
40
+ if self.component_path is not None:
41
+ self.crop_components = True
42
+ self.components_dict = torch.load(self.component_path)
43
+ self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4)
44
+ self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1)
45
+ self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3)
46
+ else:
47
+ self.crop_components = False
48
+
49
+ if self.latent_gt_path is not None:
50
+ self.load_latent_gt = True
51
+ self.latent_gt_dict = torch.load(self.latent_gt_path)
52
+ else:
53
+ self.load_latent_gt = False
54
+
55
+ if self.io_backend_opt['type'] == 'lmdb':
56
+ self.io_backend_opt['db_paths'] = self.gt_folder
57
+ if not self.gt_folder.endswith('.lmdb'):
58
+ raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}')
59
+ with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
60
+ self.paths = [line.split('.')[0] for line in fin]
61
+ else:
62
+ self.paths = paths_from_folder(self.gt_folder)
63
+
64
+ # inpainting mask
65
+ self.gen_inpaint_mask = opt.get('gen_inpaint_mask', False)
66
+ if self.gen_inpaint_mask:
67
+ logger.info(f'generate mask ...')
68
+ # self.mask_max_angle = opt.get('mask_max_angle', 10)
69
+ # self.mask_max_len = opt.get('mask_max_len', 150)
70
+ # self.mask_max_width = opt.get('mask_max_width', 50)
71
+ # self.mask_draw_times = opt.get('mask_draw_times', 10)
72
+ # # print
73
+ # logger.info(f'mask_max_angle: {self.mask_max_angle}')
74
+ # logger.info(f'mask_max_len: {self.mask_max_len}')
75
+ # logger.info(f'mask_max_width: {self.mask_max_width}')
76
+ # logger.info(f'mask_draw_times: {self.mask_draw_times}')
77
+
78
+ # perform corrupt
79
+ self.use_corrupt = opt.get('use_corrupt', True)
80
+ self.use_motion_kernel = False
81
+ # self.use_motion_kernel = opt.get('use_motion_kernel', True)
82
+
83
+ if self.use_motion_kernel:
84
+ self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
85
+ motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth')
86
+ self.motion_kernels = torch.load(motion_kernel_path)
87
+
88
+ if self.use_corrupt and not self.gen_inpaint_mask:
89
+ # degradation configurations
90
+ self.blur_kernel_size = opt['blur_kernel_size']
91
+ self.blur_sigma = opt['blur_sigma']
92
+ self.kernel_list = opt['kernel_list']
93
+ self.kernel_prob = opt['kernel_prob']
94
+ self.downsample_range = opt['downsample_range']
95
+ self.noise_range = opt['noise_range']
96
+ self.jpeg_range = opt['jpeg_range']
97
+ # print
98
+ logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
99
+ logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
100
+ logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
101
+ logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
102
+
103
+ # color jitter
104
+ self.color_jitter_prob = opt.get('color_jitter_prob', None)
105
+ self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None)
106
+ self.color_jitter_shift = opt.get('color_jitter_shift', 20)
107
+ if self.color_jitter_prob is not None:
108
+ logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
109
+
110
+ # to gray
111
+ self.gray_prob = opt.get('gray_prob', 0.0)
112
+ if self.gray_prob is not None:
113
+ logger.info(f'Use random gray. Prob: {self.gray_prob}')
114
+ self.color_jitter_shift /= 255.
115
+
116
+ @staticmethod
117
+ def color_jitter(img, shift):
118
+ """jitter color: randomly jitter the RGB values, in numpy formats"""
119
+ jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
120
+ img = img + jitter_val
121
+ img = np.clip(img, 0, 1)
122
+ return img
123
+
124
+ @staticmethod
125
+ def color_jitter_pt(img, brightness, contrast, saturation, hue):
126
+ """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
127
+ fn_idx = torch.randperm(4)
128
+ for fn_id in fn_idx:
129
+ if fn_id == 0 and brightness is not None:
130
+ brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
131
+ img = adjust_brightness(img, brightness_factor)
132
+
133
+ if fn_id == 1 and contrast is not None:
134
+ contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
135
+ img = adjust_contrast(img, contrast_factor)
136
+
137
+ if fn_id == 2 and saturation is not None:
138
+ saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
139
+ img = adjust_saturation(img, saturation_factor)
140
+
141
+ if fn_id == 3 and hue is not None:
142
+ hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
143
+ img = adjust_hue(img, hue_factor)
144
+ return img
145
+
146
+
147
+ def get_component_locations(self, name, status):
148
+ components_bbox = self.components_dict[name]
149
+ if status[0]: # hflip
150
+ # exchange right and left eye
151
+ tmp = components_bbox['left_eye']
152
+ components_bbox['left_eye'] = components_bbox['right_eye']
153
+ components_bbox['right_eye'] = tmp
154
+ # modify the width coordinate
155
+ components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0]
156
+ components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0]
157
+ components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0]
158
+ components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0]
159
+
160
+ locations_gt = {}
161
+ locations_in = {}
162
+ for part in ['left_eye', 'right_eye', 'nose', 'mouth']:
163
+ mean = components_bbox[part][0:2]
164
+ half_len = components_bbox[part][2]
165
+ if 'eye' in part:
166
+ half_len *= self.eye_enlarge_ratio
167
+ elif part == 'nose':
168
+ half_len *= self.nose_enlarge_ratio
169
+ elif part == 'mouth':
170
+ half_len *= self.mouth_enlarge_ratio
171
+ loc = np.hstack((mean - half_len + 1, mean + half_len))
172
+ loc = torch.from_numpy(loc).float()
173
+ locations_gt[part] = loc
174
+ loc_in = loc/(self.gt_size//self.in_size)
175
+ locations_in[part] = loc_in
176
+ return locations_gt, locations_in
177
+
178
+
179
+ def __getitem__(self, index):
180
+ if self.file_client is None:
181
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
182
+
183
+ # load gt image
184
+ gt_path = self.paths[index]
185
+ name = osp.basename(gt_path)[:-4]
186
+ img_bytes = self.file_client.get(gt_path)
187
+ img_gt = imfrombytes(img_bytes, float32=True)
188
+
189
+ # random horizontal flip
190
+ img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
191
+
192
+ if self.load_latent_gt:
193
+ if status[0]:
194
+ latent_gt = self.latent_gt_dict['hflip'][name]
195
+ else:
196
+ latent_gt = self.latent_gt_dict['orig'][name]
197
+
198
+ if self.crop_components:
199
+ locations_gt, locations_in = self.get_component_locations(name, status)
200
+
201
+ # generate in image
202
+ img_in = img_gt
203
+ if self.use_corrupt and not self.gen_inpaint_mask:
204
+ # motion blur
205
+ if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
206
+ m_i = random.randint(0,31)
207
+ k = self.motion_kernels[f'{m_i:02d}']
208
+ img_in = cv2.filter2D(img_in,-1,k)
209
+
210
+ # gaussian blur
211
+ kernel = gaussian_kernels.random_mixed_kernels(
212
+ self.kernel_list,
213
+ self.kernel_prob,
214
+ self.blur_kernel_size,
215
+ self.blur_sigma,
216
+ self.blur_sigma,
217
+ [-math.pi, math.pi],
218
+ noise_range=None)
219
+ img_in = cv2.filter2D(img_in, -1, kernel)
220
+
221
+ # downsample
222
+ scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
223
+ img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
224
+
225
+ # noise
226
+ if self.noise_range is not None:
227
+ noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.)
228
+ noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma
229
+ img_in = img_in + noise
230
+ img_in = np.clip(img_in, 0, 1)
231
+
232
+ # jpeg
233
+ if self.jpeg_range is not None:
234
+ jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1])
235
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
236
+ _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param)
237
+ img_in = np.float32(cv2.imdecode(encimg, 1)) / 255.
238
+
239
+ # resize to in_size
240
+ img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
241
+
242
+ # if self.gen_inpaint_mask:
243
+ # inpaint_mask = random_ff_mask(shape=(self.gt_size,self.gt_size),
244
+ # max_angle = self.mask_max_angle, max_len = self.mask_max_len,
245
+ # max_width = self.mask_max_width, times = self.mask_draw_times)
246
+ # img_in = img_in * (1 - inpaint_mask.reshape(self.gt_size,self.gt_size,1)) + \
247
+ # 1.0 * inpaint_mask.reshape(self.gt_size,self.gt_size,1)
248
+
249
+ # inpaint_mask = torch.from_numpy(inpaint_mask).view(1,self.gt_size,self.gt_size)
250
+
251
+ if self.gen_inpaint_mask:
252
+ img_in = (img_in*255).astype('uint8')
253
+ img_in = brush_stroke_mask(Image.fromarray(img_in))
254
+ img_in = np.array(img_in) / 255.
255
+
256
+ # random color jitter (only for lq)
257
+ if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
258
+ img_in = self.color_jitter(img_in, self.color_jitter_shift)
259
+ # random to gray (only for lq)
260
+ if self.gray_prob and np.random.uniform() < self.gray_prob:
261
+ img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
262
+ img_in = np.tile(img_in[:, :, None], [1, 1, 3])
263
+
264
+ # BGR to RGB, HWC to CHW, numpy to tensor
265
+ img_in, img_gt = img2tensor([img_in, img_gt], bgr2rgb=True, float32=True)
266
+
267
+ # random color jitter (pytorch version) (only for lq)
268
+ if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
269
+ brightness = self.opt.get('brightness', (0.5, 1.5))
270
+ contrast = self.opt.get('contrast', (0.5, 1.5))
271
+ saturation = self.opt.get('saturation', (0, 1.5))
272
+ hue = self.opt.get('hue', (-0.1, 0.1))
273
+ img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)
274
+
275
+ # round and clip
276
+ img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255.
277
+
278
+ # Set vgg range_norm=True if using the normalization here
279
+ # normalize
280
+ normalize(img_in, self.mean, self.std, inplace=True)
281
+ normalize(img_gt, self.mean, self.std, inplace=True)
282
+
283
+ return_dict = {'in': img_in, 'gt': img_gt, 'gt_path': gt_path}
284
+
285
+ if self.crop_components:
286
+ return_dict['locations_in'] = locations_in
287
+ return_dict['locations_gt'] = locations_gt
288
+
289
+ if self.load_latent_gt:
290
+ return_dict['latent_gt'] = latent_gt
291
+
292
+ # if self.gen_inpaint_mask:
293
+ # return_dict['inpaint_mask'] = inpaint_mask
294
+
295
+ return return_dict
296
+
297
+
298
+ def __len__(self):
299
+ return len(self.paths)
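For orientation, a sketch of the degradation-related keys that `FFHQBlindDataset.__init__` reads when `use_corrupt` is enabled (illustrative only; the values below are placeholders, not the training recipe shipped with this repo).
degradation_opt = {
    'use_corrupt': True,
    'blur_kernel_size': 41,
    'kernel_list': ['iso', 'aniso'],     # kernel types understood by gaussian_kernels
    'kernel_prob': [0.5, 0.5],
    'blur_sigma': [0.1, 10],
    'downsample_range': [1, 8],
    'noise_range': [0, 20],
    'jpeg_range': [60, 100],
    'color_jitter_prob': None,           # optional extras read with opt.get(...)
    'gray_prob': 0.01,
}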
blissful_tuner/codeformer/basicsr/data/ffhq_blind_joint_dataset.py ADDED
@@ -0,0 +1,324 @@
1
+ import cv2
2
+ import math
3
+ import random
4
+ import numpy as np
5
+ import os.path as osp
6
+ from scipy.io import loadmat
7
+ import torch
8
+ import torch.utils.data as data
9
+ from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
10
+ adjust_hue, adjust_saturation, normalize)
11
+ from codeformer.basicsr.data import gaussian_kernels as gaussian_kernels
12
+ from codeformer.basicsr.data.transforms import augment
13
+ from codeformer.basicsr.data.data_util import paths_from_folder
14
+ from codeformer.basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
15
+ from codeformer.basicsr.utils.registry import DATASET_REGISTRY
16
+
17
+ @DATASET_REGISTRY.register()
18
+ class FFHQBlindJointDataset(data.Dataset):
19
+
20
+ def __init__(self, opt):
21
+ super(FFHQBlindJointDataset, self).__init__()
22
+ logger = get_root_logger()
23
+ self.opt = opt
24
+ # file client (io backend)
25
+ self.file_client = None
26
+ self.io_backend_opt = opt['io_backend']
27
+
28
+ self.gt_folder = opt['dataroot_gt']
29
+ self.gt_size = opt.get('gt_size', 512)
30
+ self.in_size = opt.get('in_size', 512)
31
+ assert self.gt_size >= self.in_size, 'Wrong setting.'
32
+
33
+ self.mean = opt.get('mean', [0.5, 0.5, 0.5])
34
+ self.std = opt.get('std', [0.5, 0.5, 0.5])
35
+
36
+ self.component_path = opt.get('component_path', None)
37
+ self.latent_gt_path = opt.get('latent_gt_path', None)
38
+
39
+ if self.component_path is not None:
40
+ self.crop_components = True
41
+ self.components_dict = torch.load(self.component_path)
42
+ self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4)
43
+ self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1)
44
+ self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3)
45
+ else:
46
+ self.crop_components = False
47
+
48
+ if self.latent_gt_path is not None:
49
+ self.load_latent_gt = True
50
+ self.latent_gt_dict = torch.load(self.latent_gt_path)
51
+ else:
52
+ self.load_latent_gt = False
53
+
54
+ if self.io_backend_opt['type'] == 'lmdb':
55
+ self.io_backend_opt['db_paths'] = self.gt_folder
56
+ if not self.gt_folder.endswith('.lmdb'):
57
+ raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}')
58
+ with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
59
+ self.paths = [line.split('.')[0] for line in fin]
60
+ else:
61
+ self.paths = paths_from_folder(self.gt_folder)
62
+
63
+ # perform corrupt
64
+ self.use_corrupt = opt.get('use_corrupt', True)
65
+ self.use_motion_kernel = False
66
+ # self.use_motion_kernel = opt.get('use_motion_kernel', True)
67
+
68
+ if self.use_motion_kernel:
69
+ self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
70
+ motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth')
71
+ self.motion_kernels = torch.load(motion_kernel_path)
72
+
73
+ if self.use_corrupt:
74
+ # degradation configurations
75
+ self.blur_kernel_size = self.opt['blur_kernel_size']
76
+ self.kernel_list = self.opt['kernel_list']
77
+ self.kernel_prob = self.opt['kernel_prob']
78
+ # Small degradation
79
+ self.blur_sigma = self.opt['blur_sigma']
80
+ self.downsample_range = self.opt['downsample_range']
81
+ self.noise_range = self.opt['noise_range']
82
+ self.jpeg_range = self.opt['jpeg_range']
83
+ # Large degradation
84
+ self.blur_sigma_large = self.opt['blur_sigma_large']
85
+ self.downsample_range_large = self.opt['downsample_range_large']
86
+ self.noise_range_large = self.opt['noise_range_large']
87
+ self.jpeg_range_large = self.opt['jpeg_range_large']
88
+
89
+ # print
90
+ logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
91
+ logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
92
+ logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
93
+ logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
94
+
95
+ # color jitter
96
+ self.color_jitter_prob = opt.get('color_jitter_prob', None)
97
+ self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None)
98
+ self.color_jitter_shift = opt.get('color_jitter_shift', 20)
99
+ if self.color_jitter_prob is not None:
100
+ logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
101
+
102
+ # to gray
103
+ self.gray_prob = opt.get('gray_prob', 0.0)
104
+ if self.gray_prob is not None:
105
+ logger.info(f'Use random gray. Prob: {self.gray_prob}')
106
+ self.color_jitter_shift /= 255.
107
+
108
+ @staticmethod
109
+ def color_jitter(img, shift):
110
+ """jitter color: randomly jitter the RGB values, in numpy formats"""
111
+ jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
112
+ img = img + jitter_val
113
+ img = np.clip(img, 0, 1)
114
+ return img
115
+
116
+ @staticmethod
117
+ def color_jitter_pt(img, brightness, contrast, saturation, hue):
118
+ """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
119
+ fn_idx = torch.randperm(4)
120
+ for fn_id in fn_idx:
121
+ if fn_id == 0 and brightness is not None:
122
+ brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
123
+ img = adjust_brightness(img, brightness_factor)
124
+
125
+ if fn_id == 1 and contrast is not None:
126
+ contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
127
+ img = adjust_contrast(img, contrast_factor)
128
+
129
+ if fn_id == 2 and saturation is not None:
130
+ saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
131
+ img = adjust_saturation(img, saturation_factor)
132
+
133
+ if fn_id == 3 and hue is not None:
134
+ hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
135
+ img = adjust_hue(img, hue_factor)
136
+ return img
137
+
138
+
139
+ def get_component_locations(self, name, status):
140
+ components_bbox = self.components_dict[name]
141
+ if status[0]: # hflip
142
+ # exchange right and left eye
143
+ tmp = components_bbox['left_eye']
144
+ components_bbox['left_eye'] = components_bbox['right_eye']
145
+ components_bbox['right_eye'] = tmp
146
+ # modify the width coordinate
147
+ components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0]
148
+ components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0]
149
+ components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0]
150
+ components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0]
151
+
152
+ locations_gt = {}
153
+ locations_in = {}
154
+ for part in ['left_eye', 'right_eye', 'nose', 'mouth']:
155
+ mean = components_bbox[part][0:2]
156
+ half_len = components_bbox[part][2]
157
+ if 'eye' in part:
158
+ half_len *= self.eye_enlarge_ratio
159
+ elif part == 'nose':
160
+ half_len *= self.nose_enlarge_ratio
161
+ elif part == 'mouth':
162
+ half_len *= self.mouth_enlarge_ratio
163
+ loc = np.hstack((mean - half_len + 1, mean + half_len))
164
+ loc = torch.from_numpy(loc).float()
165
+ locations_gt[part] = loc
166
+ loc_in = loc/(self.gt_size//self.in_size)
167
+ locations_in[part] = loc_in
168
+ return locations_gt, locations_in
169
+
170
+
171
+ def __getitem__(self, index):
172
+ if self.file_client is None:
173
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
174
+
175
+ # load gt image
176
+ gt_path = self.paths[index]
177
+ name = osp.basename(gt_path)[:-4]
178
+ img_bytes = self.file_client.get(gt_path)
179
+ img_gt = imfrombytes(img_bytes, float32=True)
180
+
181
+ # random horizontal flip
182
+ img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
183
+
184
+ if self.load_latent_gt:
185
+ if status[0]:
186
+ latent_gt = self.latent_gt_dict['hflip'][name]
187
+ else:
188
+ latent_gt = self.latent_gt_dict['orig'][name]
189
+
190
+ if self.crop_components:
191
+ locations_gt, locations_in = self.get_component_locations(name, status)
192
+
193
+ # generate in image
194
+ img_in = img_gt
195
+ if self.use_corrupt:
196
+ # motion blur
197
+ if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
198
+ m_i = random.randint(0,31)
199
+ k = self.motion_kernels[f'{m_i:02d}']
200
+ img_in = cv2.filter2D(img_in,-1,k)
201
+
202
+ # gaussian blur
203
+ kernel = gaussian_kernels.random_mixed_kernels(
204
+ self.kernel_list,
205
+ self.kernel_prob,
206
+ self.blur_kernel_size,
207
+ self.blur_sigma,
208
+ self.blur_sigma,
209
+ [-math.pi, math.pi],
210
+ noise_range=None)
211
+ img_in = cv2.filter2D(img_in, -1, kernel)
212
+
213
+ # downsample
214
+ scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
215
+ img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
216
+
217
+ # noise
218
+ if self.noise_range is not None:
219
+ noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.)
220
+ noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma
221
+ img_in = img_in + noise
222
+ img_in = np.clip(img_in, 0, 1)
223
+
224
+ # jpeg
225
+ if self.jpeg_range is not None:
226
+ jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1])
227
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
228
+ _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param)
229
+ img_in = np.float32(cv2.imdecode(encimg, 1)) / 255.
230
+
231
+ # resize to in_size
232
+ img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
233
+
234
+
235
+ # generate in_large with large degradation
236
+ img_in_large = img_gt
237
+
238
+ if self.use_corrupt:
239
+ # motion blur
240
+ if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
241
+ m_i = random.randint(0,31)
242
+ k = self.motion_kernels[f'{m_i:02d}']
243
+ img_in_large = cv2.filter2D(img_in_large,-1,k)
244
+
245
+ # gaussian blur
246
+ kernel = gaussian_kernels.random_mixed_kernels(
247
+ self.kernel_list,
248
+ self.kernel_prob,
249
+ self.blur_kernel_size,
250
+ self.blur_sigma_large,
251
+ self.blur_sigma_large,
252
+ [-math.pi, math.pi],
253
+ noise_range=None)
254
+ img_in_large = cv2.filter2D(img_in_large, -1, kernel)
255
+
256
+ # downsample
257
+ scale = np.random.uniform(self.downsample_range_large[0], self.downsample_range_large[1])
258
+ img_in_large = cv2.resize(img_in_large, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
259
+
260
+ # noise
261
+ if self.noise_range_large is not None:
262
+ noise_sigma = np.random.uniform(self.noise_range_large[0] / 255., self.noise_range_large[1] / 255.)
263
+ noise = np.float32(np.random.randn(*(img_in_large.shape))) * noise_sigma
264
+ img_in_large = img_in_large + noise
265
+ img_in_large = np.clip(img_in_large, 0, 1)
266
+
267
+ # jpeg
268
+ if self.jpeg_range_large is not None:
269
+ jpeg_p = np.random.uniform(self.jpeg_range_large[0], self.jpeg_range_large[1])
270
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
271
+ _, encimg = cv2.imencode('.jpg', img_in_large * 255., encode_param)
272
+ img_in_large = np.float32(cv2.imdecode(encimg, 1)) / 255.
273
+
274
+ # resize to in_size
275
+ img_in_large = cv2.resize(img_in_large, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
276
+
277
+
278
+ # random color jitter (only for lq)
279
+ if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
280
+ img_in = self.color_jitter(img_in, self.color_jitter_shift)
281
+ img_in_large = self.color_jitter(img_in_large, self.color_jitter_shift)
282
+ # random to gray (only for lq)
283
+ if self.gray_prob and np.random.uniform() < self.gray_prob:
284
+ img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
285
+ img_in = np.tile(img_in[:, :, None], [1, 1, 3])
286
+ img_in_large = cv2.cvtColor(img_in_large, cv2.COLOR_BGR2GRAY)
287
+ img_in_large = np.tile(img_in_large[:, :, None], [1, 1, 3])
288
+
289
+ # BGR to RGB, HWC to CHW, numpy to tensor
290
+ img_in, img_in_large, img_gt = img2tensor([img_in, img_in_large, img_gt], bgr2rgb=True, float32=True)
291
+
292
+ # random color jitter (pytorch version) (only for lq)
293
+ if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
294
+ brightness = self.opt.get('brightness', (0.5, 1.5))
295
+ contrast = self.opt.get('contrast', (0.5, 1.5))
296
+ saturation = self.opt.get('saturation', (0, 1.5))
297
+ hue = self.opt.get('hue', (-0.1, 0.1))
298
+ img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)
299
+ img_in_large = self.color_jitter_pt(img_in_large, brightness, contrast, saturation, hue)
300
+
301
+ # round and clip
302
+ img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255.
303
+ img_in_large = np.clip((img_in_large * 255.0).round(), 0, 255) / 255.
304
+
305
+ # Set vgg range_norm=True if using the normalization here
306
+ # normalize
307
+ normalize(img_in, self.mean, self.std, inplace=True)
308
+ normalize(img_in_large, self.mean, self.std, inplace=True)
309
+ normalize(img_gt, self.mean, self.std, inplace=True)
310
+
311
+ return_dict = {'in': img_in, 'in_large_de': img_in_large, 'gt': img_gt, 'gt_path': gt_path}
312
+
313
+ if self.crop_components:
314
+ return_dict['locations_in'] = locations_in
315
+ return_dict['locations_gt'] = locations_gt
316
+
317
+ if self.load_latent_gt:
318
+ return_dict['latent_gt'] = latent_gt
319
+
320
+ return return_dict
321
+
322
+
323
+ def __len__(self):
324
+ return len(self.paths)
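The __getitem__ above applies the same degradation chain twice (once with the small and once with the large settings): optional motion blur, random mixed Gaussian blur, random downsampling, additive noise, a JPEG round trip, and a resize back to in_size. A minimal standalone sketch of that chain follows, using only cv2/numpy; the fixed Gaussian kernel and the numeric ranges are illustrative placeholders, since the dataset itself draws them from its opt dict and from random_mixed_kernels.

import cv2
import numpy as np

def degrade_small(img_gt, gt_size=512, in_size=512):
    """img_gt: float32 HWC BGR image in [0, 1] with side length gt_size (placeholder values)."""
    # blur (stand-in for gaussian_kernels.random_mixed_kernels)
    k1d = cv2.getGaussianKernel(41, 3.0)
    img = cv2.filter2D(img_gt, -1, k1d @ k1d.T)
    # downsample by a random factor, as with downsample_range
    scale = np.random.uniform(1.0, 4.0)
    side = int(gt_size // scale)
    img = cv2.resize(img, (side, side), interpolation=cv2.INTER_LINEAR)
    # additive Gaussian noise, as with noise_range
    sigma = np.random.uniform(0, 15) / 255.
    noise = np.random.randn(*img.shape).astype(np.float32) * sigma
    img = np.clip(img + noise, 0, 1).astype(np.float32)
    # JPEG round trip, as with jpeg_range
    quality = int(np.random.uniform(60, 100))
    _, enc = cv2.imencode('.jpg', img * 255., [int(cv2.IMWRITE_JPEG_QUALITY), quality])
    img = np.float32(cv2.imdecode(enc, 1)) / 255.
    # resize back to the network input size
    return cv2.resize(img, (in_size, in_size), interpolation=cv2.INTER_LINEAR)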
blissful_tuner/codeformer/basicsr/data/gaussian_kernels.py ADDED
@@ -0,0 +1,690 @@
1
+ import math
2
+ import numpy as np
3
+ import random
4
+ from scipy.ndimage import shift
5
+ from scipy.stats import multivariate_normal
6
+
7
+
8
+ def sigma_matrix2(sig_x, sig_y, theta):
9
+ """Calculate the rotated sigma matrix (two dimensional matrix).
10
+ Args:
11
+ sig_x (float):
12
+ sig_y (float):
13
+ theta (float): Radian measurement.
14
+ Returns:
15
+ ndarray: Rotated sigma matrix.
16
+ """
17
+ D = np.array([[sig_x**2, 0], [0, sig_y**2]])
18
+ U = np.array([[np.cos(theta), -np.sin(theta)],
19
+ [np.sin(theta), np.cos(theta)]])
20
+ return np.dot(U, np.dot(D, U.T))
21
+
22
+
23
+ def mesh_grid(kernel_size):
24
+ """Generate the mesh grid, centering at zero.
25
+ Args:
26
+ kernel_size (int):
27
+ Returns:
28
+ xy (ndarray): with the shape (kernel_size, kernel_size, 2)
29
+ xx (ndarray): with the shape (kernel_size, kernel_size)
30
+ yy (ndarray): with the shape (kernel_size, kernel_size)
31
+ """
32
+ ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
33
+ xx, yy = np.meshgrid(ax, ax)
34
+ xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)),
35
+ yy.reshape(kernel_size * kernel_size,
36
+ 1))).reshape(kernel_size, kernel_size, 2)
37
+ return xy, xx, yy
38
+
39
+
40
+ def pdf2(sigma_matrix, grid):
41
+ """Calculate PDF of the bivariate Gaussian distribution.
42
+ Args:
43
+ sigma_matrix (ndarray): with the shape (2, 2)
44
+ grid (ndarray): generated by :func:`mesh_grid`,
45
+ with the shape (K, K, 2), K is the kernel size.
46
+ Returns:
47
+ kernel (ndarray): un-normalized kernel.
48
+ """
49
+ inverse_sigma = np.linalg.inv(sigma_matrix)
50
+ kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
51
+ return kernel
52
+
53
+
54
+ def cdf2(D, grid):
55
+ """Calculate the CDF of the standard bivariate Gaussian distribution.
56
+ Used in skewed Gaussian distribution.
57
+ Args:
58
+ D (ndarray): skew matrix.
59
+ grid (ndarray): generated by :func:`mesh_grid`,
60
+ with the shape (K, K, 2), K is the kernel size.
61
+ Returns:
62
+ cdf (ndarray): skewed cdf.
63
+ """
64
+ rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
65
+ grid = np.dot(grid, D)
66
+ cdf = rv.cdf(grid)
67
+ return cdf
68
+
69
+
70
+ def bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid=None):
71
+ """Generate a bivariate skew Gaussian kernel.
72
+ Described in `A multivariate skew normal distribution`_ by Shi et al. (2004).
73
+ Args:
74
+ kernel_size (int):
75
+ sig_x (float):
76
+ sig_y (float):
77
+ theta (float): Radian measurement.
78
+ D (ndarray): skew matrix.
79
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
80
+ with the shape (K, K, 2), K is the kernel size. Default: None
81
+ Returns:
82
+ kernel (ndarray): normalized kernel.
83
+ .. _A multivariate skew normal distribution:
84
+ https://www.sciencedirect.com/science/article/pii/S0047259X03001313
85
+ """
86
+ if grid is None:
87
+ grid, _, _ = mesh_grid(kernel_size)
88
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
89
+ pdf = pdf2(sigma_matrix, grid)
90
+ cdf = cdf2(D, grid)
91
+ kernel = pdf * cdf
92
+ kernel = kernel / np.sum(kernel)
93
+ return kernel
94
+
95
+
96
+ def mass_center_shift(kernel_size, kernel):
97
+ """Calculate the shift of the mass center of a kenrel.
98
+ Args:
99
+ kernel_size (int):
100
+ kernel (ndarray): normalized kernel.
101
+ Returns:
102
+ delta_h (float):
103
+ delta_w (float):
104
+ """
105
+ ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
106
+ col_sum, row_sum = np.sum(kernel, axis=0), np.sum(kernel, axis=1)
107
+ delta_h = np.dot(row_sum, ax)
108
+ delta_w = np.dot(col_sum, ax)
109
+ return delta_h, delta_w
110
+
111
+
112
+ def bivariate_skew_Gaussian_center(kernel_size,
113
+ sig_x,
114
+ sig_y,
115
+ theta,
116
+ D,
117
+ grid=None):
118
+ """Generate a bivariate skew Gaussian kernel at center. Shift with nearest padding.
119
+ Args:
120
+ kernel_size (int):
121
+ sig_x (float):
122
+ sig_y (float):
123
+ theta (float): Radian measurement.
124
+ D (ndarray): skew matrix.
125
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
126
+ with the shape (K, K, 2), K is the kernel size. Default: None
127
+ Returns:
128
+ kernel (ndarray): centered and normalized kernel.
129
+ """
130
+ if grid is None:
131
+ grid, _, _ = mesh_grid(kernel_size)
132
+ kernel = bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid)
133
+ delta_h, delta_w = mass_center_shift(kernel_size, kernel)
134
+ kernel = shift(kernel, [-delta_h, -delta_w], mode='nearest')
135
+ kernel = kernel / np.sum(kernel)
136
+ return kernel
137
+
138
+
139
+ def bivariate_anisotropic_Gaussian(kernel_size,
140
+ sig_x,
141
+ sig_y,
142
+ theta,
143
+ grid=None):
144
+ """Generate a bivariate anisotropic Gaussian kernel.
145
+ Args:
146
+ kernel_size (int):
147
+ sig_x (float):
148
+ sig_y (float):
149
+ theta (float): Radian measurement.
150
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
151
+ with the shape (K, K, 2), K is the kernel size. Default: None
152
+ Returns:
153
+ kernel (ndarray): normalized kernel.
154
+ """
155
+ if grid is None:
156
+ grid, _, _ = mesh_grid(kernel_size)
157
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
158
+ kernel = pdf2(sigma_matrix, grid)
159
+ kernel = kernel / np.sum(kernel)
160
+ return kernel
161
+
162
+
163
+ def bivariate_isotropic_Gaussian(kernel_size, sig, grid=None):
164
+ """Generate a bivariate isotropic Gaussian kernel.
165
+ Args:
166
+ kernel_size (int):
167
+ sig (float):
168
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
169
+ with the shape (K, K, 2), K is the kernel size. Default: None
170
+ Returns:
171
+ kernel (ndarray): normalized kernel.
172
+ """
173
+ if grid is None:
174
+ grid, _, _ = mesh_grid(kernel_size)
175
+ sigma_matrix = np.array([[sig**2, 0], [0, sig**2]])
176
+ kernel = pdf2(sigma_matrix, grid)
177
+ kernel = kernel / np.sum(kernel)
178
+ return kernel
179
+
180
+
181
+ def bivariate_generalized_Gaussian(kernel_size,
182
+ sig_x,
183
+ sig_y,
184
+ theta,
185
+ beta,
186
+ grid=None):
187
+ """Generate a bivariate generalized Gaussian kernel.
188
+ Described in `Parameter Estimation For Multivariate Generalized Gaussian Distributions`_
189
+ by Pascal et al. (2013).
190
+ Args:
191
+ kernel_size (int):
192
+ sig_x (float):
193
+ sig_y (float):
194
+ theta (float): Radian measurement.
195
+ beta (float): shape parameter, beta = 1 is the normal distribution.
196
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
197
+ with the shape (K, K, 2), K is the kernel size. Default: None
198
+ Returns:
199
+ kernel (ndarray): normalized kernel.
200
+ .. _Parameter Estimation For Multivariate Generalized Gaussian Distributions:
201
+ https://arxiv.org/abs/1302.6498
202
+ """
203
+ if grid is None:
204
+ grid, _, _ = mesh_grid(kernel_size)
205
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
206
+ inverse_sigma = np.linalg.inv(sigma_matrix)
207
+ kernel = np.exp(
208
+ -0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
209
+ kernel = kernel / np.sum(kernel)
210
+ return kernel
211
+
212
+
213
+ def bivariate_plateau_type1(kernel_size, sig_x, sig_y, theta, beta, grid=None):
214
+ """Generate a plateau-like anisotropic kernel.
215
+ 1 / (1+x^(beta))
216
+ Args:
217
+ kernel_size (int):
218
+ sig_x (float):
219
+ sig_y (float):
220
+ theta (float): Radian measurement.
221
+ beta (float): shape parameter of the plateau profile.
222
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
223
+ with the shape (K, K, 2), K is the kernel size. Default: None
224
+ Returns:
225
+ kernel (ndarray): normalized kernel.
226
+ """
227
+ if grid is None:
228
+ grid, _, _ = mesh_grid(kernel_size)
229
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
230
+ inverse_sigma = np.linalg.inv(sigma_matrix)
231
+ kernel = np.reciprocal(
232
+ np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
233
+ kernel = kernel / np.sum(kernel)
234
+ return kernel
235
+
236
+
237
+ def bivariate_plateau_type1_iso(kernel_size, sig, beta, grid=None):
238
+ """Generate a plateau-like isotropic kernel.
239
+ 1 / (1+x^(beta))
240
+ Args:
241
+ kernel_size (int):
242
+ sig (float):
243
+ beta (float): shape parameter of the plateau profile.
244
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
245
+ with the shape (K, K, 2), K is the kernel size. Default: None
246
+ Returns:
247
+ kernel (ndarray): normalized kernel.
248
+ """
249
+ if grid is None:
250
+ grid, _, _ = mesh_grid(kernel_size)
251
+ sigma_matrix = np.array([[sig**2, 0], [0, sig**2]])
252
+ inverse_sigma = np.linalg.inv(sigma_matrix)
253
+ kernel = np.reciprocal(
254
+ np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
255
+ kernel = kernel / np.sum(kernel)
256
+ return kernel
257
+
258
+
259
+ def random_bivariate_skew_Gaussian_center(kernel_size,
260
+ sigma_x_range,
261
+ sigma_y_range,
262
+ rotation_range,
263
+ noise_range=None,
264
+ strict=False):
265
+ """Randomly generate bivariate skew Gaussian kernels at center.
266
+ Args:
267
+ kernel_size (int):
268
+ sigma_x_range (tuple): [0.6, 5]
269
+ sigma_y_range (tuple): [0.6, 5]
270
+ rotation_range (tuple): [-math.pi, math.pi]
271
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
272
+ Returns:
273
+ kernel (ndarray):
274
+ """
275
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
276
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
277
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
278
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
279
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
280
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
281
+ if strict:
282
+ sigma_max = np.max([sigma_x, sigma_y])
283
+ sigma_min = np.min([sigma_x, sigma_y])
284
+ sigma_x, sigma_y = sigma_max, sigma_min
285
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
286
+
287
+ sigma_max = np.max([sigma_x, sigma_y])
288
+ thres = 3 / sigma_max
289
+ D = [[np.random.uniform(-thres, thres),
290
+ np.random.uniform(-thres, thres)],
291
+ [np.random.uniform(-thres, thres),
292
+ np.random.uniform(-thres, thres)]]
293
+
294
+ kernel = bivariate_skew_Gaussian_center(kernel_size, sigma_x, sigma_y,
295
+ rotation, D)
296
+
297
+ # add multiplicative noise
298
+ if noise_range is not None:
299
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
300
+ noise = np.random.uniform(
301
+ noise_range[0], noise_range[1], size=kernel.shape)
302
+ kernel = kernel * noise
303
+ kernel = kernel / np.sum(kernel)
304
+ if strict:
305
+ return kernel, sigma_x, sigma_y, rotation, D
306
+ else:
307
+ return kernel
308
+
309
+
310
+ def random_bivariate_anisotropic_Gaussian(kernel_size,
311
+ sigma_x_range,
312
+ sigma_y_range,
313
+ rotation_range,
314
+ noise_range=None,
315
+ strict=False):
316
+ """Randomly generate bivariate anisotropic Gaussian kernels.
317
+ Args:
318
+ kernel_size (int):
319
+ sigma_x_range (tuple): [0.6, 5]
320
+ sigma_y_range (tuple): [0.6, 5]
321
+ rotation_range (tuple): [-math.pi, math.pi]
322
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
323
+ Returns:
324
+ kernel (ndarray):
325
+ """
326
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
327
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
328
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
329
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
330
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
331
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
332
+ if strict:
333
+ sigma_max = np.max([sigma_x, sigma_y])
334
+ sigma_min = np.min([sigma_x, sigma_y])
335
+ sigma_x, sigma_y = sigma_max, sigma_min
336
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
337
+
338
+ kernel = bivariate_anisotropic_Gaussian(kernel_size, sigma_x, sigma_y,
339
+ rotation)
340
+
341
+ # add multiplicative noise
342
+ if noise_range is not None:
343
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
344
+ noise = np.random.uniform(
345
+ noise_range[0], noise_range[1], size=kernel.shape)
346
+ kernel = kernel * noise
347
+ kernel = kernel / np.sum(kernel)
348
+ if strict:
349
+ return kernel, sigma_x, sigma_y, rotation
350
+ else:
351
+ return kernel
352
+
353
+
354
+ def random_bivariate_isotropic_Gaussian(kernel_size,
355
+ sigma_range,
356
+ noise_range=None,
357
+ strict=False):
358
+ """Randomly generate bivariate isotropic Gaussian kernels.
359
+ Args:
360
+ kernel_size (int):
361
+ sigma_range (tuple): [0.6, 5]
362
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
363
+ Returns:
364
+ kernel (ndarray):
365
+ """
366
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
367
+ assert sigma_range[0] < sigma_range[1], 'Wrong sigma_range.'
368
+ sigma = np.random.uniform(sigma_range[0], sigma_range[1])
369
+
370
+ kernel = bivariate_isotropic_Gaussian(kernel_size, sigma)
371
+
372
+ # add multiplicative noise
373
+ if noise_range is not None:
374
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
375
+ noise = np.random.uniform(
376
+ noise_range[0], noise_range[1], size=kernel.shape)
377
+ kernel = kernel * noise
378
+ kernel = kernel / np.sum(kernel)
379
+ if strict:
380
+ return kernel, sigma
381
+ else:
382
+ return kernel
383
+
384
+
385
+ def random_bivariate_generalized_Gaussian(kernel_size,
386
+ sigma_x_range,
387
+ sigma_y_range,
388
+ rotation_range,
389
+ beta_range,
390
+ noise_range=None,
391
+ strict=False):
392
+ """Randomly generate bivariate generalized Gaussian kernels.
393
+ Args:
394
+ kernel_size (int):
395
+ sigma_x_range (tuple): [0.6, 5]
396
+ sigma_y_range (tuple): [0.6, 5]
397
+ rotation_range (tuple): [-math.pi, math.pi]
398
+ beta_range (tuple): [0.5, 8]
399
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
400
+ Returns:
401
+ kernel (ndarray):
402
+ """
403
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
404
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
405
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
406
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
407
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
408
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
409
+ if strict:
410
+ sigma_max = np.max([sigma_x, sigma_y])
411
+ sigma_min = np.min([sigma_x, sigma_y])
412
+ sigma_x, sigma_y = sigma_max, sigma_min
413
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
414
+ if np.random.uniform() < 0.5:
415
+ beta = np.random.uniform(beta_range[0], 1)
416
+ else:
417
+ beta = np.random.uniform(1, beta_range[1])
418
+
419
+ kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y,
420
+ rotation, beta)
421
+
422
+ # add multiplicative noise
423
+ if noise_range is not None:
424
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
425
+ noise = np.random.uniform(
426
+ noise_range[0], noise_range[1], size=kernel.shape)
427
+ kernel = kernel * noise
428
+ kernel = kernel / np.sum(kernel)
429
+ if strict:
430
+ return kernel, sigma_x, sigma_y, rotation, beta
431
+ else:
432
+ return kernel
433
+
434
+
435
+ def random_bivariate_plateau_type1(kernel_size,
436
+ sigma_x_range,
437
+ sigma_y_range,
438
+ rotation_range,
439
+ beta_range,
440
+ noise_range=None,
441
+ strict=False):
442
+ """Randomly generate bivariate plateau type1 kernels.
443
+ Args:
444
+ kernel_size (int):
445
+ sigma_x_range (tuple): [0.6, 5]
446
+ sigma_y_range (tuple): [0.6, 5]
447
+ rotation_range (tuple): [-math.pi/2, math.pi/2]
448
+ beta_range (tuple): [1, 4]
449
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
450
+ Returns:
451
+ kernel (ndarray):
452
+ """
453
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
454
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
455
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
456
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
457
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
458
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
459
+ if strict:
460
+ sigma_max = np.max([sigma_x, sigma_y])
461
+ sigma_min = np.min([sigma_x, sigma_y])
462
+ sigma_x, sigma_y = sigma_max, sigma_min
463
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
464
+ if np.random.uniform() < 0.5:
465
+ beta = np.random.uniform(beta_range[0], 1)
466
+ else:
467
+ beta = np.random.uniform(1, beta_range[1])
468
+
469
+ kernel = bivariate_plateau_type1(kernel_size, sigma_x, sigma_y, rotation,
470
+ beta)
471
+
472
+ # add multiplicative noise
473
+ if noise_range is not None:
474
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
475
+ noise = np.random.uniform(
476
+ noise_range[0], noise_range[1], size=kernel.shape)
477
+ kernel = kernel * noise
478
+ kernel = kernel / np.sum(kernel)
479
+ if strict:
480
+ return kernel, sigma_x, sigma_y, rotation, beta
481
+ else:
482
+ return kernel
483
+
484
+
485
+ def random_bivariate_plateau_type1_iso(kernel_size,
486
+ sigma_range,
487
+ beta_range,
488
+ noise_range=None,
489
+ strict=False):
490
+ """Randomly generate bivariate plateau type1 kernels (iso).
491
+ Args:
492
+ kernel_size (int):
493
+ sigma_range (tuple): [0.6, 5]
494
+ beta_range (tuple): [1, 4]
495
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
496
+ Returns:
497
+ kernel (ndarray):
498
+ """
499
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
500
+ assert sigma_range[0] < sigma_range[1], 'Wrong sigma_range.'
501
+ sigma = np.random.uniform(sigma_range[0], sigma_range[1])
502
+ beta = np.random.uniform(beta_range[0], beta_range[1])
503
+
504
+ kernel = bivariate_plateau_type1_iso(kernel_size, sigma, beta)
505
+
506
+ # add multiplicative noise
507
+ if noise_range is not None:
508
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
509
+ noise = np.random.uniform(
510
+ noise_range[0], noise_range[1], size=kernel.shape)
511
+ kernel = kernel * noise
512
+ kernel = kernel / np.sum(kernel)
513
+ if strict:
514
+ return kernel, sigma, beta
515
+ else:
516
+ return kernel
517
+
518
+
519
+ def random_mixed_kernels(kernel_list,
520
+ kernel_prob,
521
+ kernel_size=21,
522
+ sigma_x_range=[0.6, 5],
523
+ sigma_y_range=[0.6, 5],
524
+ rotation_range=[-math.pi, math.pi],
525
+ beta_range=[0.5, 8],
526
+ noise_range=None):
527
+ """Randomly generate mixed kernels.
528
+ Args:
529
+ kernel_list (tuple): a list of kernel type names,
530
+ support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso', 'plateau_aniso']
531
+ kernel_prob (tuple): corresponding kernel probability for each kernel type
532
+ kernel_size (int):
533
+ sigma_x_range (tuple): [0.6, 5]
534
+ sigma_y_range (tuple): [0.6, 5]
535
+ rotation_range (tuple): [-math.pi, math.pi]
536
+ beta_range (tuple): [0.5, 8]
537
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
538
+ Returns:
539
+ kernel (ndarray):
540
+ """
541
+ kernel_type = random.choices(kernel_list, kernel_prob)[0]
542
+ if kernel_type == 'iso':
543
+ kernel = random_bivariate_isotropic_Gaussian(
544
+ kernel_size, sigma_x_range, noise_range=noise_range)
545
+ elif kernel_type == 'aniso':
546
+ kernel = random_bivariate_anisotropic_Gaussian(
547
+ kernel_size,
548
+ sigma_x_range,
549
+ sigma_y_range,
550
+ rotation_range,
551
+ noise_range=noise_range)
552
+ elif kernel_type == 'skew':
553
+ kernel = random_bivariate_skew_Gaussian_center(
554
+ kernel_size,
555
+ sigma_x_range,
556
+ sigma_y_range,
557
+ rotation_range,
558
+ noise_range=noise_range)
559
+ elif kernel_type == 'generalized':
560
+ kernel = random_bivariate_generalized_Gaussian(
561
+ kernel_size,
562
+ sigma_x_range,
563
+ sigma_y_range,
564
+ rotation_range,
565
+ beta_range,
566
+ noise_range=noise_range)
567
+ elif kernel_type == 'plateau_iso':
568
+ kernel = random_bivariate_plateau_type1_iso(
569
+ kernel_size, sigma_x_range, beta_range, noise_range=noise_range)
570
+ elif kernel_type == 'plateau_aniso':
571
+ kernel = random_bivariate_plateau_type1(
572
+ kernel_size,
573
+ sigma_x_range,
574
+ sigma_y_range,
575
+ rotation_range,
576
+ beta_range,
577
+ noise_range=noise_range)
578
+ # add multiplicative noise
579
+ if noise_range is not None:
580
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
581
+ noise = np.random.uniform(
582
+ noise_range[0], noise_range[1], size=kernel.shape)
583
+ kernel = kernel * noise
584
+ kernel = kernel / np.sum(kernel)
585
+ return kernel
586
+
587
+
588
+ def show_one_kernel():
589
+ import matplotlib.pyplot as plt
590
+ kernel_size = 21
591
+
592
+ # bivariate skew Gaussian
593
+ D = [[0, 0], [0, 0]]
594
+ D = [[3 / 4, 0], [0, 0.5]]
595
+ kernel = bivariate_skew_Gaussian_center(kernel_size, 2, 4, -math.pi / 4, D)
596
+ # bivariate anisotropic Gaussian
597
+ kernel = bivariate_anisotropic_Gaussian(kernel_size, 2, 4, -math.pi / 4)
598
+ # bivariate isotropic Gaussian
599
+ kernel = bivariate_isotropic_Gaussian(kernel_size, 1)
600
+ # bivariate generalized Gaussian
601
+ kernel = bivariate_generalized_Gaussian(
602
+ kernel_size, 2, 4, -math.pi / 4, beta=4)
603
+
604
+ delta_h, delta_w = mass_center_shift(kernel_size, kernel)
605
+ print(delta_h, delta_w)
606
+
607
+ fig, axs = plt.subplots(nrows=2, ncols=2)
608
+ # axs.set_axis_off()
609
+ ax = axs[0][0]
610
+ im = ax.matshow(kernel, cmap='jet', origin='upper')
611
+ fig.colorbar(im, ax=ax)
612
+
613
+ # image
614
+ ax = axs[0][1]
615
+ kernel_vis = kernel - np.min(kernel)
616
+ kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
617
+ ax.imshow(kernel_vis, interpolation='nearest')
618
+
619
+ _, xx, yy = mesh_grid(kernel_size)
620
+ # contour
621
+ ax = axs[1][0]
622
+ CS = ax.contour(xx, yy, kernel, origin='upper')
623
+ ax.clabel(CS, inline=1, fontsize=3)
624
+
625
+ # contourf
626
+ ax = axs[1][1]
627
+ kernel = kernel / np.max(kernel)
628
+ p = ax.contourf(
629
+ xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
630
+ fig.colorbar(p)
631
+
632
+ plt.show()
633
+
634
+
635
+ def show_plateau_kernel():
636
+ import matplotlib.pyplot as plt
637
+ kernel_size = 21
638
+
639
+ kernel = bivariate_plateau_type1(kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
640
+ kernel_norm = bivariate_isotropic_Gaussian(kernel_size, 5)
641
+ kernel_gau = bivariate_generalized_Gaussian(
642
+ kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
643
+ delta_h, delta_w = mass_center_shift(kernel_size, kernel)
644
+ print(delta_h, delta_w)
645
+
646
+ # kernel_slice = kernel[10, :]
647
+ # kernel_gau_slice = kernel_gau[10, :]
648
+ # kernel_norm_slice = kernel_norm[10, :]
649
+ # fig, ax = plt.subplots()
650
+ # t = list(range(1, 22))
651
+
652
+ # ax.plot(t, kernel_gau_slice)
653
+ # ax.plot(t, kernel_slice)
654
+ # ax.plot(t, kernel_norm_slice)
655
+
656
+ # t = np.arange(0, 10, 0.1)
657
+ # y = np.exp(-0.5 * t)
658
+ # y2 = np.reciprocal(1 + t)
659
+ # print(t.shape)
660
+ # print(y.shape)
661
+ # ax.plot(t, y)
662
+ # ax.plot(t, y2)
663
+ # plt.show()
664
+
665
+ fig, axs = plt.subplots(nrows=2, ncols=2)
666
+ # axs.set_axis_off()
667
+ ax = axs[0][0]
668
+ im = ax.matshow(kernel, cmap='jet', origin='upper')
669
+ fig.colorbar(im, ax=ax)
670
+
671
+ # image
672
+ ax = axs[0][1]
673
+ kernel_vis = kernel - np.min(kernel)
674
+ kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
675
+ ax.imshow(kernel_vis, interpolation='nearest')
676
+
677
+ _, xx, yy = mesh_grid(kernel_size)
678
+ # contour
679
+ ax = axs[1][0]
680
+ CS = ax.contour(xx, yy, kernel, origin='upper')
681
+ ax.clabel(CS, inline=1, fontsize=3)
682
+
683
+ # contourf
684
+ ax = axs[1][1]
685
+ kernel = kernel / np.max(kernel)
686
+ p = ax.contourf(
687
+ xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
688
+ fig.colorbar(p)
689
+
690
+ plt.show()
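A hedged usage sketch of random_mixed_kernels as the datasets above call it; the kernel list, probabilities, and ranges below are illustrative placeholders, not the repository's training configuration, and the import path simply mirrors this module's location.

import math
import cv2
import numpy as np
from codeformer.basicsr.data import gaussian_kernels

img = np.random.rand(512, 512, 3).astype(np.float32)   # placeholder image in [0, 1]
kernel = gaussian_kernels.random_mixed_kernels(
    kernel_list=['iso', 'aniso', 'generalized', 'plateau_iso', 'plateau_aniso'],
    kernel_prob=[0.4, 0.3, 0.1, 0.1, 0.1],
    kernel_size=41,                       # must be odd
    sigma_x_range=[0.6, 5],
    sigma_y_range=[0.6, 5],
    rotation_range=[-math.pi, math.pi],
    beta_range=[0.5, 8],
    noise_range=None)
blurred = cv2.filter2D(img, -1, kernel)   # convolve, as the datasets above do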
blissful_tuner/codeformer/basicsr/data/paired_image_dataset.py ADDED
@@ -0,0 +1,101 @@
1
+ from torch.utils import data as data
2
+ from torchvision.transforms.functional import normalize
3
+
4
+ from codeformer.basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
5
+ from codeformer.basicsr.data.transforms import augment, paired_random_crop
6
+ from codeformer.basicsr.utils import FileClient, imfrombytes, img2tensor
7
+ from codeformer.basicsr.utils.registry import DATASET_REGISTRY
8
+
9
+
10
+ @DATASET_REGISTRY.register()
11
+ class PairedImageDataset(data.Dataset):
12
+ """Paired image dataset for image restoration.
13
+
14
+ Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
15
+ GT image pairs.
16
+
17
+ There are three modes:
18
+ 1. 'lmdb': Use lmdb files.
19
+ If opt['io_backend'] == lmdb.
20
+ 2. 'meta_info_file': Use meta information file to generate paths.
21
+ If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
22
+ 3. 'folder': Scan folders to generate paths.
23
+ The rest.
24
+
25
+ Args:
26
+ opt (dict): Config for train datasets. It contains the following keys:
27
+ dataroot_gt (str): Data root path for gt.
28
+ dataroot_lq (str): Data root path for lq.
29
+ meta_info_file (str): Path for meta information file.
30
+ io_backend (dict): IO backend type and other kwarg.
31
+ filename_tmpl (str): Template for each filename. Note that the
32
+ template excludes the file extension. Default: '{}'.
33
+ gt_size (int): Cropped patch size for gt patches.
34
+ use_flip (bool): Use horizontal flips.
35
+ use_rot (bool): Use rotation (use vertical flip and transposing h
36
+ and w for implementation).
37
+
38
+ scale (bool): Scale, which will be added automatically.
39
+ phase (str): 'train' or 'val'.
40
+ """
41
+
42
+ def __init__(self, opt):
43
+ super(PairedImageDataset, self).__init__()
44
+ self.opt = opt
45
+ # file client (io backend)
46
+ self.file_client = None
47
+ self.io_backend_opt = opt['io_backend']
48
+ self.mean = opt['mean'] if 'mean' in opt else None
49
+ self.std = opt['std'] if 'std' in opt else None
50
+
51
+ self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
52
+ if 'filename_tmpl' in opt:
53
+ self.filename_tmpl = opt['filename_tmpl']
54
+ else:
55
+ self.filename_tmpl = '{}'
56
+
57
+ if self.io_backend_opt['type'] == 'lmdb':
58
+ self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
59
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
60
+ self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
61
+ elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
62
+ self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
63
+ self.opt['meta_info_file'], self.filename_tmpl)
64
+ else:
65
+ self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
66
+
67
+ def __getitem__(self, index):
68
+ if self.file_client is None:
69
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
70
+
71
+ scale = self.opt['scale']
72
+
73
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
74
+ # image range: [0, 1], float32.
75
+ gt_path = self.paths[index]['gt_path']
76
+ img_bytes = self.file_client.get(gt_path, 'gt')
77
+ img_gt = imfrombytes(img_bytes, float32=True)
78
+ lq_path = self.paths[index]['lq_path']
79
+ img_bytes = self.file_client.get(lq_path, 'lq')
80
+ img_lq = imfrombytes(img_bytes, float32=True)
81
+
82
+ # augmentation for training
83
+ if self.opt['phase'] == 'train':
84
+ gt_size = self.opt['gt_size']
85
+ # random crop
86
+ img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
87
+ # flip, rotation
88
+ img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])
89
+
90
+ # TODO: color space transform
91
+ # BGR to RGB, HWC to CHW, numpy to tensor
92
+ img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
93
+ # normalize
94
+ if self.mean is not None or self.std is not None:
95
+ normalize(img_lq, self.mean, self.std, inplace=True)
96
+ normalize(img_gt, self.mean, self.std, inplace=True)
97
+
98
+ return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
99
+
100
+ def __len__(self):
101
+ return len(self.paths)
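For reference, a hypothetical opt dict that satisfies the keys read by PairedImageDataset above; 'disk' is the plain file-system FileClient backend, and the paths and sizes are placeholders.

# assumes: from codeformer.basicsr.data.paired_image_dataset import PairedImageDataset
opt = {
    'phase': 'train',
    'scale': 4,
    'gt_size': 256,
    'use_flip': True,
    'use_rot': True,
    'io_backend': {'type': 'disk'},
    'dataroot_gt': 'datasets/example/HR',        # placeholder path
    'dataroot_lq': 'datasets/example/LR_x4',     # placeholder path
}
dataset = PairedImageDataset(opt)
sample = dataset[0]   # dict with keys 'lq', 'gt', 'lq_path', 'gt_path'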
blissful_tuner/codeformer/basicsr/data/prefetch_dataloader.py ADDED
@@ -0,0 +1,125 @@
1
+ import queue as Queue
2
+ import threading
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+
6
+
7
+ class PrefetchGenerator(threading.Thread):
8
+ """A general prefetch generator.
9
+
10
+ Ref:
11
+ https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
12
+
13
+ Args:
14
+ generator: Python generator.
15
+ num_prefetch_queue (int): Number of prefetch queue.
16
+ """
17
+
18
+ def __init__(self, generator, num_prefetch_queue):
19
+ threading.Thread.__init__(self)
20
+ self.queue = Queue.Queue(num_prefetch_queue)
21
+ self.generator = generator
22
+ self.daemon = True
23
+ self.start()
24
+
25
+ def run(self):
26
+ for item in self.generator:
27
+ self.queue.put(item)
28
+ self.queue.put(None)
29
+
30
+ def __next__(self):
31
+ next_item = self.queue.get()
32
+ if next_item is None:
33
+ raise StopIteration
34
+ return next_item
35
+
36
+ def __iter__(self):
37
+ return self
38
+
39
+
40
+ class PrefetchDataLoader(DataLoader):
41
+ """Prefetch version of dataloader.
42
+
43
+ Ref:
44
+ https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
45
+
46
+ TODO:
47
+ Need to test on single gpu and ddp (multi-gpu). There is a known issue in
48
+ ddp.
49
+
50
+ Args:
51
+ num_prefetch_queue (int): Number of prefetch queue.
52
+ kwargs (dict): Other arguments for dataloader.
53
+ """
54
+
55
+ def __init__(self, num_prefetch_queue, **kwargs):
56
+ self.num_prefetch_queue = num_prefetch_queue
57
+ super(PrefetchDataLoader, self).__init__(**kwargs)
58
+
59
+ def __iter__(self):
60
+ return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
61
+
62
+
63
+ class CPUPrefetcher():
64
+ """CPU prefetcher.
65
+
66
+ Args:
67
+ loader: Dataloader.
68
+ """
69
+
70
+ def __init__(self, loader):
71
+ self.ori_loader = loader
72
+ self.loader = iter(loader)
73
+
74
+ def next(self):
75
+ try:
76
+ return next(self.loader)
77
+ except StopIteration:
78
+ return None
79
+
80
+ def reset(self):
81
+ self.loader = iter(self.ori_loader)
82
+
83
+
84
+ class CUDAPrefetcher():
85
+ """CUDA prefetcher.
86
+
87
+ Ref:
88
+ https://github.com/NVIDIA/apex/issues/304#
89
+
90
+ It may consume more GPU memory.
91
+
92
+ Args:
93
+ loader: Dataloader.
94
+ opt (dict): Options.
95
+ """
96
+
97
+ def __init__(self, loader, opt):
98
+ self.ori_loader = loader
99
+ self.loader = iter(loader)
100
+ self.opt = opt
101
+ self.stream = torch.cuda.Stream()
102
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
103
+ self.preload()
104
+
105
+ def preload(self):
106
+ try:
107
+ self.batch = next(self.loader) # self.batch is a dict
108
+ except StopIteration:
109
+ self.batch = None
110
+ return None
111
+ # put tensors to gpu
112
+ with torch.cuda.stream(self.stream):
113
+ for k, v in self.batch.items():
114
+ if torch.is_tensor(v):
115
+ self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
116
+
117
+ def next(self):
118
+ torch.cuda.current_stream().wait_stream(self.stream)
119
+ batch = self.batch
120
+ self.preload()
121
+ return batch
122
+
123
+ def reset(self):
124
+ self.loader = iter(self.ori_loader)
125
+ self.preload()
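A sketch of how CUDAPrefetcher is typically driven in a training loop; train_loader is assumed to be an ordinary PyTorch DataLoader yielding dicts with 'lq' and 'gt' tensors, and opt only needs 'num_gpu' here.

prefetcher = CUDAPrefetcher(train_loader, opt={'num_gpu': 1})
batch = prefetcher.next()
while batch is not None:
    lq, gt = batch['lq'], batch['gt']   # tensors already copied to the GPU on a side stream
    # ... forward pass, loss, optimizer step ...
    batch = prefetcher.next()
prefetcher.reset()                      # rewind before the next epoch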
blissful_tuner/codeformer/basicsr/data/transforms.py ADDED
@@ -0,0 +1,165 @@
1
+ import cv2
2
+ import random
3
+
4
+
5
+ def mod_crop(img, scale):
6
+ """Mod crop images, used during testing.
7
+
8
+ Args:
9
+ img (ndarray): Input image.
10
+ scale (int): Scale factor.
11
+
12
+ Returns:
13
+ ndarray: Result image.
14
+ """
15
+ img = img.copy()
16
+ if img.ndim in (2, 3):
17
+ h, w = img.shape[0], img.shape[1]
18
+ h_remainder, w_remainder = h % scale, w % scale
19
+ img = img[:h - h_remainder, :w - w_remainder, ...]
20
+ else:
21
+ raise ValueError(f'Wrong img ndim: {img.ndim}.')
22
+ return img
23
+
24
+
25
+ def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
26
+ """Paired random crop.
27
+
28
+ It crops lists of lq and gt images with corresponding locations.
29
+
30
+ Args:
31
+ img_gts (list[ndarray] | ndarray): GT images. Note that all images
32
+ should have the same shape. If the input is an ndarray, it will
33
+ be transformed to a list containing itself.
34
+ img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
35
+ should have the same shape. If the input is an ndarray, it will
36
+ be transformed to a list containing itself.
37
+ gt_patch_size (int): GT patch size.
38
+ scale (int): Scale factor.
39
+ gt_path (str): Path to ground-truth.
40
+
41
+ Returns:
42
+ list[ndarray] | ndarray: GT images and LQ images. If returned results
43
+ only have one element, just return ndarray.
44
+ """
45
+
46
+ if not isinstance(img_gts, list):
47
+ img_gts = [img_gts]
48
+ if not isinstance(img_lqs, list):
49
+ img_lqs = [img_lqs]
50
+
51
+ h_lq, w_lq, _ = img_lqs[0].shape
52
+ h_gt, w_gt, _ = img_gts[0].shape
53
+ lq_patch_size = gt_patch_size // scale
54
+
55
+ if h_gt != h_lq * scale or w_gt != w_lq * scale:
56
+ raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
57
+ f'multiplication of LQ ({h_lq}, {w_lq}).')
58
+ if h_lq < lq_patch_size or w_lq < lq_patch_size:
59
+ raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
60
+ f'({lq_patch_size}, {lq_patch_size}). '
61
+ f'Please remove {gt_path}.')
62
+
63
+ # randomly choose top and left coordinates for lq patch
64
+ top = random.randint(0, h_lq - lq_patch_size)
65
+ left = random.randint(0, w_lq - lq_patch_size)
66
+
67
+ # crop lq patch
68
+ img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
69
+
70
+ # crop corresponding gt patch
71
+ top_gt, left_gt = int(top * scale), int(left * scale)
72
+ img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
73
+ if len(img_gts) == 1:
74
+ img_gts = img_gts[0]
75
+ if len(img_lqs) == 1:
76
+ img_lqs = img_lqs[0]
77
+ return img_gts, img_lqs
78
+
79
+
80
+ def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
81
+ """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
82
+
83
+ We use vertical flip and transpose for rotation implementation.
84
+ All the images in the list use the same augmentation.
85
+
86
+ Args:
87
+ imgs (list[ndarray] | ndarray): Images to be augmented. If the input
88
+ is an ndarray, it will be transformed to a list.
89
+ hflip (bool): Horizontal flip. Default: True.
90
+ rotation (bool): Rotation. Default: True.
91
+ flows (list[ndarray]): Flows to be augmented. If the input is an
92
+ ndarray, it will be transformed to a list.
93
+ Dimension is (h, w, 2). Default: None.
94
+ return_status (bool): Return the status of flip and rotation.
95
+ Default: False.
96
+
97
+ Returns:
98
+ list[ndarray] | ndarray: Augmented images and flows. If returned
99
+ results only have one element, just return ndarray.
100
+
101
+ """
102
+ hflip = hflip and random.random() < 0.5
103
+ vflip = rotation and random.random() < 0.5
104
+ rot90 = rotation and random.random() < 0.5
105
+
106
+ def _augment(img):
107
+ if hflip: # horizontal
108
+ cv2.flip(img, 1, img)
109
+ if vflip: # vertical
110
+ cv2.flip(img, 0, img)
111
+ if rot90:
112
+ img = img.transpose(1, 0, 2)
113
+ return img
114
+
115
+ def _augment_flow(flow):
116
+ if hflip: # horizontal
117
+ cv2.flip(flow, 1, flow)
118
+ flow[:, :, 0] *= -1
119
+ if vflip: # vertical
120
+ cv2.flip(flow, 0, flow)
121
+ flow[:, :, 1] *= -1
122
+ if rot90:
123
+ flow = flow.transpose(1, 0, 2)
124
+ flow = flow[:, :, [1, 0]]
125
+ return flow
126
+
127
+ if not isinstance(imgs, list):
128
+ imgs = [imgs]
129
+ imgs = [_augment(img) for img in imgs]
130
+ if len(imgs) == 1:
131
+ imgs = imgs[0]
132
+
133
+ if flows is not None:
134
+ if not isinstance(flows, list):
135
+ flows = [flows]
136
+ flows = [_augment_flow(flow) for flow in flows]
137
+ if len(flows) == 1:
138
+ flows = flows[0]
139
+ return imgs, flows
140
+ else:
141
+ if return_status:
142
+ return imgs, (hflip, vflip, rot90)
143
+ else:
144
+ return imgs
145
+
146
+
147
+ def img_rotate(img, angle, center=None, scale=1.0):
148
+ """Rotate image.
149
+
150
+ Args:
151
+ img (ndarray): Image to be rotated.
152
+ angle (float): Rotation angle in degrees. Positive values mean
153
+ counter-clockwise rotation.
154
+ center (tuple[int]): Rotation center. If the center is None,
155
+ initialize it as the center of the image. Default: None.
156
+ scale (float): Isotropic scale factor. Default: 1.0.
157
+ """
158
+ (h, w) = img.shape[:2]
159
+
160
+ if center is None:
161
+ center = (w // 2, h // 2)
162
+
163
+ matrix = cv2.getRotationMatrix2D(center, angle, scale)
164
+ rotated_img = cv2.warpAffine(img, matrix, (w, h))
165
+ return rotated_img
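The call order used by PairedImageDataset above, shown in isolation with placeholder arrays; the patch size, scale, and gt_path are illustrative.

import numpy as np

img_gt = np.random.rand(1024, 1024, 3).astype(np.float32)   # placeholder GT
img_lq = np.random.rand(256, 256, 3).astype(np.float32)     # placeholder LQ (4x smaller)
# crop matching patches from GT/LQ, then apply the same flip/rotation to both
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_patch_size=256, scale=4, gt_path='example.png')
img_gt, img_lq = augment([img_gt, img_lq], hflip=True, rotation=True)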
blissful_tuner/codeformer/basicsr/losses/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ from copy import deepcopy
2
+
3
+ from codeformer.basicsr.utils import get_root_logger
4
+ from codeformer.basicsr.utils.registry import LOSS_REGISTRY
5
+ from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
6
+ gradient_penalty_loss, r1_penalty)
7
+
8
+ __all__ = [
9
+ 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
10
+ 'r1_penalty', 'g_path_regularize'
11
+ ]
12
+
13
+
14
+ def build_loss(opt):
15
+ """Build loss from options.
16
+
17
+ Args:
18
+ opt (dict): Configuration. It must contain:
19
+ type (str): Loss type.
20
+ """
21
+ opt = deepcopy(opt)
22
+ loss_type = opt.pop('type')
23
+ loss = LOSS_REGISTRY.get(loss_type)(**opt)
24
+ logger = get_root_logger()
25
+ logger.info(f'Loss [{loss.__class__.__name__}] is created.')
26
+ return loss
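build_loss pops 'type' and forwards the remaining keys as constructor kwargs to the registered class; a minimal sketch with a hypothetical config dict.

import torch
# assumes: from codeformer.basicsr.losses import build_loss
pixel_opt = {'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'}
cri_pix = build_loss(pixel_opt)      # instantiates L1Loss via LOSS_REGISTRY
pred = torch.rand(2, 3, 64, 64)      # placeholder tensors
gt = torch.rand(2, 3, 64, 64)
loss = cri_pix(pred, gt)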
blissful_tuner/codeformer/basicsr/losses/loss_util.py ADDED
@@ -0,0 +1,95 @@
1
+ import functools
2
+ from torch.nn import functional as F
3
+
4
+
5
+ def reduce_loss(loss, reduction):
6
+ """Reduce loss as specified.
7
+
8
+ Args:
9
+ loss (Tensor): Elementwise loss tensor.
10
+ reduction (str): Options are 'none', 'mean' and 'sum'.
11
+
12
+ Returns:
13
+ Tensor: Reduced loss tensor.
14
+ """
15
+ reduction_enum = F._Reduction.get_enum(reduction)
16
+ # none: 0, elementwise_mean:1, sum: 2
17
+ if reduction_enum == 0:
18
+ return loss
19
+ elif reduction_enum == 1:
20
+ return loss.mean()
21
+ else:
22
+ return loss.sum()
23
+
24
+
25
+ def weight_reduce_loss(loss, weight=None, reduction='mean'):
26
+ """Apply element-wise weight and reduce loss.
27
+
28
+ Args:
29
+ loss (Tensor): Element-wise loss.
30
+ weight (Tensor): Element-wise weights. Default: None.
31
+ reduction (str): Same as built-in losses of PyTorch. Options are
32
+ 'none', 'mean' and 'sum'. Default: 'mean'.
33
+
34
+ Returns:
35
+ Tensor: Loss values.
36
+ """
37
+ # if weight is specified, apply element-wise weight
38
+ if weight is not None:
39
+ assert weight.dim() == loss.dim()
40
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
41
+ loss = loss * weight
42
+
43
+ # if weight is not specified or reduction is sum, just reduce the loss
44
+ if weight is None or reduction == 'sum':
45
+ loss = reduce_loss(loss, reduction)
46
+ # if reduction is mean, then compute mean over weight region
47
+ elif reduction == 'mean':
48
+ if weight.size(1) > 1:
49
+ weight = weight.sum()
50
+ else:
51
+ weight = weight.sum() * loss.size(1)
52
+ loss = loss.sum() / weight
53
+
54
+ return loss
55
+
56
+
57
+ def weighted_loss(loss_func):
58
+ """Create a weighted version of a given loss function.
59
+
60
+ To use this decorator, the loss function must have the signature like
61
+ `loss_func(pred, target, **kwargs)`. The function only needs to compute
62
+ element-wise loss without any reduction. This decorator will add weight
63
+ and reduction arguments to the function. The decorated function will have
64
+ the signature like `loss_func(pred, target, weight=None, reduction='mean',
65
+ **kwargs)`.
66
+
67
+ :Example:
68
+
69
+ >>> import torch
70
+ >>> @weighted_loss
71
+ >>> def l1_loss(pred, target):
72
+ >>> return (pred - target).abs()
73
+
74
+ >>> pred = torch.Tensor([0, 2, 3])
75
+ >>> target = torch.Tensor([1, 1, 1])
76
+ >>> weight = torch.Tensor([1, 0, 1])
77
+
78
+ >>> l1_loss(pred, target)
79
+ tensor(1.3333)
80
+ >>> l1_loss(pred, target, weight)
81
+ tensor(1.5000)
82
+ >>> l1_loss(pred, target, reduction='none')
83
+ tensor([1., 1., 2.])
84
+ >>> l1_loss(pred, target, weight, reduction='sum')
85
+ tensor(3.)
86
+ """
87
+
88
+ @functools.wraps(loss_func)
89
+ def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
90
+ # get element-wise loss
91
+ loss = loss_func(pred, target, **kwargs)
92
+ loss = weight_reduce_loss(loss, weight, reduction)
93
+ return loss
94
+
95
+ return wrapper
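A small numeric check of weight_reduce_loss with 2-D inputs (values chosen purely for illustration).

import torch

loss = torch.tensor([[1.0, 2.0, 3.0]])
weight = torch.tensor([[1.0, 0.0, 1.0]])
# elementwise weighting, then division by weight.sum(): (1*1 + 2*0 + 3*1) / 2 = 2.0
weight_reduce_loss(loss, weight, 'mean')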
blissful_tuner/codeformer/basicsr/losses/losses.py ADDED
@@ -0,0 +1,455 @@
1
+ import math
2
+ import lpips
3
+ import torch
4
+ from torch import autograd as autograd
5
+ from torch import nn as nn
6
+ from torch.nn import functional as F
7
+
8
+ from codeformer.basicsr.archs.vgg_arch import VGGFeatureExtractor
9
+ from codeformer.basicsr.utils.registry import LOSS_REGISTRY
10
+ from .loss_util import weighted_loss
11
+
12
+ _reduction_modes = ['none', 'mean', 'sum']
13
+
14
+
15
+ @weighted_loss
16
+ def l1_loss(pred, target):
17
+ return F.l1_loss(pred, target, reduction='none')
18
+
19
+
20
+ @weighted_loss
21
+ def mse_loss(pred, target):
22
+ return F.mse_loss(pred, target, reduction='none')
23
+
24
+
25
+ @weighted_loss
26
+ def charbonnier_loss(pred, target, eps=1e-12):
27
+ return torch.sqrt((pred - target)**2 + eps)
28
+
29
+
30
+ @LOSS_REGISTRY.register()
31
+ class L1Loss(nn.Module):
32
+ """L1 (mean absolute error, MAE) loss.
33
+
34
+ Args:
35
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
36
+ reduction (str): Specifies the reduction to apply to the output.
37
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
38
+ """
39
+
40
+ def __init__(self, loss_weight=1.0, reduction='mean'):
41
+ super(L1Loss, self).__init__()
42
+ if reduction not in ['none', 'mean', 'sum']:
43
+ raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
44
+
45
+ self.loss_weight = loss_weight
46
+ self.reduction = reduction
47
+
48
+ def forward(self, pred, target, weight=None, **kwargs):
49
+ """
50
+ Args:
51
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
52
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
53
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
54
+ weights. Default: None.
55
+ """
56
+ return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
57
+
58
+
59
+ @LOSS_REGISTRY.register()
60
+ class MSELoss(nn.Module):
61
+ """MSE (L2) loss.
62
+
63
+ Args:
64
+ loss_weight (float): Loss weight for MSE loss. Default: 1.0.
65
+ reduction (str): Specifies the reduction to apply to the output.
66
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
67
+ """
68
+
69
+ def __init__(self, loss_weight=1.0, reduction='mean'):
70
+ super(MSELoss, self).__init__()
71
+ if reduction not in ['none', 'mean', 'sum']:
72
+ raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
73
+
74
+ self.loss_weight = loss_weight
75
+ self.reduction = reduction
76
+
77
+ def forward(self, pred, target, weight=None, **kwargs):
78
+ """
79
+ Args:
80
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
81
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
82
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
83
+ weights. Default: None.
84
+ """
85
+ return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
86
+
87
+
88
+ @LOSS_REGISTRY.register()
89
+ class CharbonnierLoss(nn.Module):
90
+ """Charbonnier loss (one variant of Robust L1Loss, a differentiable
91
+ variant of L1Loss).
92
+
93
+ Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
94
+ Super-Resolution".
95
+
96
+ Args:
97
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
98
+ reduction (str): Specifies the reduction to apply to the output.
99
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
100
+ eps (float): A value used to control the curvature near zero.
101
+ Default: 1e-12.
102
+ """
103
+
104
+ def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
105
+ super(CharbonnierLoss, self).__init__()
106
+ if reduction not in ['none', 'mean', 'sum']:
107
+ raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
108
+
109
+ self.loss_weight = loss_weight
110
+ self.reduction = reduction
111
+ self.eps = eps
112
+
113
+ def forward(self, pred, target, weight=None, **kwargs):
114
+ """
115
+ Args:
116
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
117
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
118
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
119
+ weights. Default: None.
120
+ """
121
+ return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
122
+
123
+
124
+ @LOSS_REGISTRY.register()
125
+ class WeightedTVLoss(L1Loss):
126
+ """Weighted TV loss.
127
+
128
+ Args:
129
+ loss_weight (float): Loss weight. Default: 1.0.
130
+ """
131
+
132
+ def __init__(self, loss_weight=1.0):
133
+ super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)
134
+
135
+ def forward(self, pred, weight=None):
+ # Guard against weight=None so the TV term can also be used without a weight map.
+ y_weight = None if weight is None else weight[:, :, :-1, :]
+ x_weight = None if weight is None else weight[:, :, :, :-1]
+ y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
+ x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)
138
+
139
+ loss = x_diff + y_diff
140
+
141
+ return loss
142
+
143
+
144
+ @LOSS_REGISTRY.register()
145
+ class PerceptualLoss(nn.Module):
146
+ """Perceptual loss with commonly used style loss.
147
+
148
+ Args:
149
+ layer_weights (dict): The weight for each layer of vgg feature.
150
+ Here is an example: {'conv5_4': 1.}, which means the conv5_4
151
+ feature layer (before relu5_4) will be extracted with weight
152
+ 1.0 in calculating losses.
153
+ vgg_type (str): The type of vgg network used as feature extractor.
154
+ Default: 'vgg19'.
155
+ use_input_norm (bool): If True, normalize the input image in vgg.
156
+ Default: True.
157
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
158
+ Default: False.
159
+ perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
160
+ loss will be calculated and the loss will be multiplied by the
161
+ weight. Default: 1.0.
162
+ style_weight (float): If `style_weight > 0`, the style loss will be
163
+ calculated and the loss will be multiplied by the weight.
164
+ Default: 0.
165
+ criterion (str): Criterion used for perceptual loss. Default: 'l1'.
166
+ """
167
+
168
+ def __init__(self,
169
+ layer_weights,
170
+ vgg_type='vgg19',
171
+ use_input_norm=True,
172
+ range_norm=False,
173
+ perceptual_weight=1.0,
174
+ style_weight=0.,
175
+ criterion='l1'):
176
+ super(PerceptualLoss, self).__init__()
177
+ self.perceptual_weight = perceptual_weight
178
+ self.style_weight = style_weight
179
+ self.layer_weights = layer_weights
180
+ self.vgg = VGGFeatureExtractor(
181
+ layer_name_list=list(layer_weights.keys()),
182
+ vgg_type=vgg_type,
183
+ use_input_norm=use_input_norm,
184
+ range_norm=range_norm)
185
+
186
+ self.criterion_type = criterion
187
+ if self.criterion_type == 'l1':
188
+ self.criterion = torch.nn.L1Loss()
189
+ elif self.criterion_type == 'l2':
190
+ self.criterion = torch.nn.MSELoss()  # torch has no nn.L2loss; MSELoss is the L2 criterion
191
+ elif self.criterion_type == 'mse':
192
+ self.criterion = torch.nn.MSELoss(reduction='mean')
193
+ elif self.criterion_type == 'fro':
194
+ self.criterion = None
195
+ else:
196
+ raise NotImplementedError(f'{criterion} criterion has not been supported.')
197
+
198
+ def forward(self, x, gt):
199
+ """Forward function.
200
+
201
+ Args:
202
+ x (Tensor): Input tensor with shape (n, c, h, w).
203
+ gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
204
+
205
+ Returns:
206
+ Tensor: Forward results.
207
+ """
208
+ # extract vgg features
209
+ x_features = self.vgg(x)
210
+ gt_features = self.vgg(gt.detach())
211
+
212
+ # calculate perceptual loss
213
+ if self.perceptual_weight > 0:
214
+ percep_loss = 0
215
+ for k in x_features.keys():
216
+ if self.criterion_type == 'fro':
217
+ percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
218
+ else:
219
+ percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
220
+ percep_loss *= self.perceptual_weight
221
+ else:
222
+ percep_loss = None
223
+
224
+ # calculate style loss
225
+ if self.style_weight > 0:
226
+ style_loss = 0
227
+ for k in x_features.keys():
228
+ if self.criterion_type == 'fro':
229
+ style_loss += torch.norm(
230
+ self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
231
+ else:
232
+ style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
233
+ gt_features[k])) * self.layer_weights[k]
234
+ style_loss *= self.style_weight
235
+ else:
236
+ style_loss = None
237
+
238
+ return percep_loss, style_loss
239
+
240
+ def _gram_mat(self, x):
241
+ """Calculate Gram matrix.
242
+
243
+ Args:
244
+ x (torch.Tensor): Tensor with shape of (n, c, h, w).
245
+
246
+ Returns:
247
+ torch.Tensor: Gram matrix.
248
+ """
249
+ n, c, h, w = x.size()
250
+ features = x.view(n, c, w * h)
251
+ features_t = features.transpose(1, 2)
252
+ gram = features.bmm(features_t) / (c * h * w)
253
+ return gram
254
+
255
+
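# Illustrative sketch: the Gram matrix used by the style term above is the
# channel-by-channel feature correlation normalized by c * h * w, so its shape
# is (n, c, c) regardless of the spatial size of the features.
import torch

n, c, h, w = 2, 8, 16, 16
feat = torch.randn(n, c, h, w)
flat = feat.view(n, c, h * w)
gram = flat.bmm(flat.transpose(1, 2)) / (c * h * w)
assert gram.shape == (n, c, c)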
256
+ @LOSS_REGISTRY.register()
257
+ class LPIPSLoss(nn.Module):
258
+ def __init__(self,
259
+ loss_weight=1.0,
260
+ use_input_norm=True,
261
+ range_norm=False,):
262
+ super(LPIPSLoss, self).__init__()
263
+ self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
264
+ self.loss_weight = loss_weight
265
+ self.use_input_norm = use_input_norm
266
+ self.range_norm = range_norm
267
+
268
+ if self.use_input_norm:
269
+ # the mean is for image with range [0, 1]
270
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
271
+ # the std is for image with range [0, 1]
272
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
273
+
274
+ def forward(self, pred, target):
275
+ if self.range_norm:
276
+ pred = (pred + 1) / 2
277
+ target = (target + 1) / 2
278
+ if self.use_input_norm:
279
+ pred = (pred - self.mean) / self.std
280
+ target = (target - self.mean) / self.std
281
+ lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
282
+ return self.loss_weight * lpips_loss.mean()
283
+
284
+
285
+ @LOSS_REGISTRY.register()
286
+ class GANLoss(nn.Module):
287
+ """Define GAN loss.
288
+
289
+ Args:
290
+ gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
291
+ real_label_val (float): The value for real label. Default: 1.0.
292
+ fake_label_val (float): The value for fake label. Default: 0.0.
293
+ loss_weight (float): Loss weight. Default: 1.0.
294
+ Note that loss_weight is only for generators; it is always 1.0
295
+ for discriminators.
296
+ """
297
+
298
+ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
299
+ super(GANLoss, self).__init__()
300
+ self.gan_type = gan_type
301
+ self.loss_weight = loss_weight
302
+ self.real_label_val = real_label_val
303
+ self.fake_label_val = fake_label_val
304
+
305
+ if self.gan_type == 'vanilla':
306
+ self.loss = nn.BCEWithLogitsLoss()
307
+ elif self.gan_type == 'lsgan':
308
+ self.loss = nn.MSELoss()
309
+ elif self.gan_type == 'wgan':
310
+ self.loss = self._wgan_loss
311
+ elif self.gan_type == 'wgan_softplus':
312
+ self.loss = self._wgan_softplus_loss
313
+ elif self.gan_type == 'hinge':
314
+ self.loss = nn.ReLU()
315
+ else:
316
+ raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
317
+
318
+ def _wgan_loss(self, input, target):
319
+ """wgan loss.
320
+
321
+ Args:
322
+ input (Tensor): Input tensor.
323
+ target (bool): Target label.
324
+
325
+ Returns:
326
+ Tensor: wgan loss.
327
+ """
328
+ return -input.mean() if target else input.mean()
329
+
330
+ def _wgan_softplus_loss(self, input, target):
331
+ """wgan loss with soft plus. softplus is a smooth approximation to the
332
+ ReLU function.
333
+
334
+ In StyleGAN2, it is called:
335
+ Logistic loss for discriminator;
336
+ Non-saturating loss for generator.
337
+
338
+ Args:
339
+ input (Tensor): Input tensor.
340
+ target (bool): Target label.
341
+
342
+ Returns:
343
+ Tensor: wgan loss.
344
+ """
345
+ return F.softplus(-input).mean() if target else F.softplus(input).mean()
346
+
347
+ def get_target_label(self, input, target_is_real):
348
+ """Get target label.
349
+
350
+ Args:
351
+ input (Tensor): Input tensor.
352
+ target_is_real (bool): Whether the target is real or fake.
353
+
354
+ Returns:
355
+ (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
356
+ return Tensor.
357
+ """
358
+
359
+ if self.gan_type in ['wgan', 'wgan_softplus']:
360
+ return target_is_real
361
+ target_val = (self.real_label_val if target_is_real else self.fake_label_val)
362
+ return input.new_ones(input.size()) * target_val
363
+
364
+ def forward(self, input, target_is_real, is_disc=False):
365
+ """
366
+ Args:
367
+ input (Tensor): The input for the loss module, i.e., the network
368
+ prediction.
369
+ target_is_real (bool): Whether the target is real or fake.
370
+ is_disc (bool): Whether the loss for discriminators or not.
371
+ Default: False.
372
+
373
+ Returns:
374
+ Tensor: GAN loss value.
375
+ """
376
+ if self.gan_type == 'hinge':
377
+ if is_disc: # for discriminators in hinge-gan
378
+ input = -input if target_is_real else input
379
+ loss = self.loss(1 + input).mean()
380
+ else: # for generators in hinge-gan
381
+ loss = -input.mean()
382
+ else: # other gan types
383
+ target_label = self.get_target_label(input, target_is_real)
384
+ loss = self.loss(input, target_label)
385
+
386
+ # loss_weight is always 1.0 for discriminators
387
+ return loss if is_disc else loss * self.loss_weight
388
+
389
+
390
+ def r1_penalty(real_pred, real_img):
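# Minimal usage sketch for GANLoss (assuming this module is importable as
# codeformer.basicsr.losses.losses; adjust the import to your install).
import torch
from codeformer.basicsr.losses.losses import GANLoss

gan_loss = GANLoss(gan_type='vanilla', loss_weight=0.1)
fake_pred = torch.randn(4, 1)  # discriminator logits on generated images
real_pred = torch.randn(4, 1)  # discriminator logits on real images

l_g = gan_loss(fake_pred, True, is_disc=False)       # generator term, scaled by loss_weight
l_d_real = gan_loss(real_pred, True, is_disc=True)   # discriminator terms keep weight 1.0
l_d_fake = gan_loss(fake_pred, False, is_disc=True)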
391
+ """R1 regularization for discriminator. The core idea is to
392
+ penalize the gradient on real data alone: when the
393
+ generator distribution produces the true data distribution
394
+ and the discriminator is equal to 0 on the data manifold, the
395
+ gradient penalty ensures that the discriminator cannot create
396
+ a non-zero gradient orthogonal to the data manifold without
397
+ suffering a loss in the GAN game.
398
+
399
+ Ref:
400
+ Eq. 9 in Which training methods for GANs do actually converge.
401
+ """
402
+ grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
403
+ grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
404
+ return grad_penalty
405
+
406
+
407
+ def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
408
+ noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
409
+ grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
410
+ path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
411
+
412
+ path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
413
+
414
+ path_penalty = (path_lengths - path_mean).pow(2).mean()
415
+
416
+ return path_penalty, path_lengths.detach().mean(), path_mean.detach()
417
+
418
+
419
+ def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
420
+ """Calculate gradient penalty for wgan-gp.
421
+
422
+ Args:
423
+ discriminator (nn.Module): Network for the discriminator.
424
+ real_data (Tensor): Real input data.
425
+ fake_data (Tensor): Fake input data.
426
+ weight (Tensor): Weight tensor. Default: None.
427
+
428
+ Returns:
429
+ Tensor: A tensor for gradient penalty.
430
+ """
431
+
432
+ batch_size = real_data.size(0)
433
+ alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
434
+
435
+ # interpolate between real_data and fake_data
436
+ interpolates = alpha * real_data + (1. - alpha) * fake_data
437
+ interpolates = interpolates.requires_grad_(True)  # autograd.Variable is deprecated; this is equivalent
438
+
439
+ disc_interpolates = discriminator(interpolates)
440
+ gradients = autograd.grad(
441
+ outputs=disc_interpolates,
442
+ inputs=interpolates,
443
+ grad_outputs=torch.ones_like(disc_interpolates),
444
+ create_graph=True,
445
+ retain_graph=True,
446
+ only_inputs=True)[0]
447
+
448
+ if weight is not None:
449
+ gradients = gradients * weight
450
+
451
+ gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
452
+ if weight is not None:
453
+ gradients_penalty /= torch.mean(weight)
454
+
455
+ return gradients_penalty
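# Minimal WGAN-GP sketch (assuming codeformer.basicsr.losses.losses is
# importable; the tiny discriminator below is purely illustrative).
import torch
import torch.nn as nn
from codeformer.basicsr.losses.losses import gradient_penalty_loss

disc = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 1))
real = torch.randn(2, 3, 16, 16)
fake = torch.randn(2, 3, 16, 16)
gp = gradient_penalty_loss(disc, real, fake)  # usually added to the D loss with a weight around 10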
blissful_tuner/codeformer/basicsr/metrics/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ from copy import deepcopy
2
+
3
+ from codeformer.basicsr.utils.registry import METRIC_REGISTRY
4
+ from .psnr_ssim import calculate_psnr, calculate_ssim
5
+
6
+ __all__ = ['calculate_psnr', 'calculate_ssim']
7
+
8
+
9
+ def calculate_metric(data, opt):
10
+ """Calculate metric from data and options.
11
+
12
+ Args:
13
+ opt (dict): Configuration. It must contain:
14
+ type (str): Metric type.
15
+ """
16
+ opt = deepcopy(opt)
17
+ metric_type = opt.pop('type')
18
+ metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
19
+ return metric
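# Minimal usage sketch (assuming the package is importable as codeformer.basicsr):
# the opt dict mirrors one `val/metrics` entry of a BasicSR-style config.
import numpy as np
from codeformer.basicsr.metrics import calculate_metric

img1 = np.random.randint(0, 256, (64, 64, 3)).astype(np.float64)
img2 = np.random.randint(0, 256, (64, 64, 3)).astype(np.float64)
psnr = calculate_metric(dict(img1=img1, img2=img2), {'type': 'calculate_psnr', 'crop_border': 4})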
blissful_tuner/codeformer/basicsr/metrics/metric_util.py ADDED
@@ -0,0 +1,45 @@
1
+ import numpy as np
2
+
3
+ from codeformer.basicsr.utils.matlab_functions import bgr2ycbcr
4
+
5
+
6
+ def reorder_image(img, input_order='HWC'):
7
+ """Reorder images to 'HWC' order.
8
+
9
+ If the input_order is (h, w), return (h, w, 1);
10
+ If the input_order is (c, h, w), return (h, w, c);
11
+ If the input_order is (h, w, c), return as it is.
12
+
13
+ Args:
14
+ img (ndarray): Input image.
15
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
16
+ If the input image shape is (h, w), input_order will not have
17
+ effects. Default: 'HWC'.
18
+
19
+ Returns:
20
+ ndarray: reordered image.
21
+ """
22
+
23
+ if input_order not in ['HWC', 'CHW']:
24
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
25
+ if len(img.shape) == 2:
26
+ img = img[..., None]
27
+ if input_order == 'CHW':
28
+ img = img.transpose(1, 2, 0)
29
+ return img
30
+
31
+
32
+ def to_y_channel(img):
33
+ """Change to Y channel of YCbCr.
34
+
35
+ Args:
36
+ img (ndarray): Images with range [0, 255].
37
+
38
+ Returns:
39
+ (ndarray): Images with range [0, 255] (float type) without round.
40
+ """
41
+ img = img.astype(np.float32) / 255.
42
+ if img.ndim == 3 and img.shape[2] == 3:
43
+ img = bgr2ycbcr(img, y_only=True)
44
+ img = img[..., None]
45
+ return img * 255.
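# Illustrative sketch (import path assumed from this repo layout): reorder_image
# normalizes layouts to HWC and to_y_channel keeps only the luma of a BGR image
# in the [0, 255] range.
import numpy as np
from codeformer.basicsr.metrics.metric_util import reorder_image, to_y_channel

chw = np.random.rand(3, 32, 32) * 255.
hwc = reorder_image(chw, input_order='CHW')  # -> shape (32, 32, 3)
y = to_y_channel(hwc)                        # -> shape (32, 32, 1), still in [0, 255]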
blissful_tuner/codeformer/basicsr/metrics/psnr_ssim.py ADDED
@@ -0,0 +1,128 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+ from codeformer.basicsr.metrics.metric_util import reorder_image, to_y_channel
5
+ from codeformer.basicsr.utils.registry import METRIC_REGISTRY
6
+
7
+
8
+ @METRIC_REGISTRY.register()
9
+ def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
10
+ """Calculate PSNR (Peak Signal-to-Noise Ratio).
11
+
12
+ Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
13
+
14
+ Args:
15
+ img1 (ndarray): Images with range [0, 255].
16
+ img2 (ndarray): Images with range [0, 255].
17
+ crop_border (int): Cropped pixels in each edge of an image. These
18
+ pixels are not involved in the PSNR calculation.
19
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
20
+ Default: 'HWC'.
21
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
22
+
23
+ Returns:
24
+ float: psnr result.
25
+ """
26
+
27
+ assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
28
+ if input_order not in ['HWC', 'CHW']:
29
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
30
+ img1 = reorder_image(img1, input_order=input_order)
31
+ img2 = reorder_image(img2, input_order=input_order)
32
+ img1 = img1.astype(np.float64)
33
+ img2 = img2.astype(np.float64)
34
+
35
+ if crop_border != 0:
36
+ img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
37
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
38
+
39
+ if test_y_channel:
40
+ img1 = to_y_channel(img1)
41
+ img2 = to_y_channel(img2)
42
+
43
+ mse = np.mean((img1 - img2)**2)
44
+ if mse == 0:
45
+ return float('inf')
46
+ return 20. * np.log10(255. / np.sqrt(mse))
47
+
48
+
49
+ def _ssim(img1, img2):
50
+ """Calculate SSIM (structural similarity) for one channel images.
51
+
52
+ It is called by func:`calculate_ssim`.
53
+
54
+ Args:
55
+ img1 (ndarray): Images with range [0, 255] with order 'HWC'.
56
+ img2 (ndarray): Images with range [0, 255] with order 'HWC'.
57
+
58
+ Returns:
59
+ float: ssim result.
60
+ """
61
+
62
+ C1 = (0.01 * 255)**2
63
+ C2 = (0.03 * 255)**2
64
+
65
+ img1 = img1.astype(np.float64)
66
+ img2 = img2.astype(np.float64)
67
+ kernel = cv2.getGaussianKernel(11, 1.5)
68
+ window = np.outer(kernel, kernel.transpose())
69
+
70
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
71
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
72
+ mu1_sq = mu1**2
73
+ mu2_sq = mu2**2
74
+ mu1_mu2 = mu1 * mu2
75
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
76
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
77
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
78
+
79
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
80
+ return ssim_map.mean()
81
+
82
+
83
+ @METRIC_REGISTRY.register()
84
+ def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
85
+ """Calculate SSIM (structural similarity).
86
+
87
+ Ref:
88
+ Image quality assessment: From error visibility to structural similarity
89
+
90
+ The results are the same as that of the official released MATLAB code in
91
+ https://ece.uwaterloo.ca/~z70wang/research/ssim/.
92
+
93
+ For three-channel images, SSIM is calculated for each channel and then
94
+ averaged.
95
+
96
+ Args:
97
+ img1 (ndarray): Images with range [0, 255].
98
+ img2 (ndarray): Images with range [0, 255].
99
+ crop_border (int): Cropped pixels in each edge of an image. These
100
+ pixels are not involved in the SSIM calculation.
101
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
102
+ Default: 'HWC'.
103
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
104
+
105
+ Returns:
106
+ float: ssim result.
107
+ """
108
+
109
+ assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
110
+ if input_order not in ['HWC', 'CHW']:
111
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
112
+ img1 = reorder_image(img1, input_order=input_order)
113
+ img2 = reorder_image(img2, input_order=input_order)
114
+ img1 = img1.astype(np.float64)
115
+ img2 = img2.astype(np.float64)
116
+
117
+ if crop_border != 0:
118
+ img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
119
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
120
+
121
+ if test_y_channel:
122
+ img1 = to_y_channel(img1)
123
+ img2 = to_y_channel(img2)
124
+
125
+ ssims = []
126
+ for i in range(img1.shape[2]):
127
+ ssims.append(_ssim(img1[..., i], img2[..., i]))
128
+ return np.array(ssims).mean()
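# Minimal usage sketch (import path assumed from this repo layout): both metrics
# expect images in [0, 255]; crop_border trims edge pixels before scoring.
import numpy as np
from codeformer.basicsr.metrics.psnr_ssim import calculate_psnr, calculate_ssim

gt = np.random.randint(0, 256, (64, 64, 3)).astype(np.float64)
restored = np.clip(gt + np.random.randn(64, 64, 3) * 5, 0, 255)
print(calculate_psnr(restored, gt, crop_border=4))
print(calculate_ssim(restored, gt, crop_border=4))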
blissful_tuner/codeformer/basicsr/models/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ import importlib
2
+ from copy import deepcopy
3
+ from os import path as osp
4
+
5
+ from codeformer.basicsr.utils import get_root_logger, scandir
6
+ from codeformer.basicsr.utils.registry import MODEL_REGISTRY
7
+
8
+ __all__ = ['build_model']
9
+
10
+ # automatically scan and import model modules for registry
11
+ # scan all the files under the 'models' folder and collect files ending with
12
+ # '_model.py'
13
+ model_folder = osp.dirname(osp.abspath(__file__))
14
+ model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
15
+ # import all the model modules
16
+ _model_modules = [importlib.import_module(f'codeformer.basicsr.models.{file_name}') for file_name in model_filenames]
17
+
18
+
19
+ def build_model(opt):
20
+ """Build model from options.
21
+
22
+ Args:
23
+ opt (dict): Configuration. It must constain:
24
+ model_type (str): Model type.
25
+ """
26
+ opt = deepcopy(opt)
27
+ model = MODEL_REGISTRY.get(opt['model_type'])(opt)
28
+ logger = get_root_logger()
29
+ logger.info(f'Model [{model.__class__.__name__}] is created.')
30
+ return model
blissful_tuner/codeformer/basicsr/models/base_model.py ADDED
@@ -0,0 +1,322 @@
1
+ import logging
2
+ import os
3
+ import torch
4
+ from collections import OrderedDict
5
+ from copy import deepcopy
6
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
7
+
8
+ from codeformer.basicsr.models import lr_scheduler as lr_scheduler
9
+ from codeformer.basicsr.utils.dist_util import master_only
10
+
11
+ logger = logging.getLogger('basicsr')
12
+
13
+
14
+ class BaseModel():
15
+ """Base model."""
16
+
17
+ def __init__(self, opt):
18
+ self.opt = opt
19
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
20
+ self.is_train = opt['is_train']
21
+ self.schedulers = []
22
+ self.optimizers = []
23
+
24
+ def feed_data(self, data):
25
+ pass
26
+
27
+ def optimize_parameters(self):
28
+ pass
29
+
30
+ def get_current_visuals(self):
31
+ pass
32
+
33
+ def save(self, epoch, current_iter):
34
+ """Save networks and training state."""
35
+ pass
36
+
37
+ def validation(self, dataloader, current_iter, tb_logger, save_img=False):
38
+ """Validation function.
39
+
40
+ Args:
41
+ dataloader (torch.utils.data.DataLoader): Validation dataloader.
42
+ current_iter (int): Current iteration.
43
+ tb_logger (tensorboard logger): Tensorboard logger.
44
+ save_img (bool): Whether to save images. Default: False.
45
+ """
46
+ if self.opt['dist']:
47
+ self.dist_validation(dataloader, current_iter, tb_logger, save_img)
48
+ else:
49
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
50
+
51
+ def model_ema(self, decay=0.999):
52
+ net_g = self.get_bare_model(self.net_g)
53
+
54
+ net_g_params = dict(net_g.named_parameters())
55
+ net_g_ema_params = dict(self.net_g_ema.named_parameters())
56
+
57
+ for k in net_g_ema_params.keys():
58
+ net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)
59
+
60
+ def get_current_log(self):
61
+ return self.log_dict
62
+
63
+ def model_to_device(self, net):
64
+ """Model to device. It also warps models with DistributedDataParallel
65
+ or DataParallel.
66
+
67
+ Args:
68
+ net (nn.Module)
69
+ """
70
+ net = net.to(self.device)
71
+ if self.opt['dist']:
72
+ find_unused_parameters = self.opt.get('find_unused_parameters', False)
73
+ net = DistributedDataParallel(
74
+ net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
75
+ elif self.opt['num_gpu'] > 1:
76
+ net = DataParallel(net)
77
+ return net
78
+
79
+ def get_optimizer(self, optim_type, params, lr, **kwargs):
80
+ if optim_type == 'Adam':
81
+ optimizer = torch.optim.Adam(params, lr, **kwargs)
82
+ else:
83
+ raise NotImplementedError(f'optimizer {optim_type} is not supported yet.')
84
+ return optimizer
85
+
86
+ def setup_schedulers(self):
87
+ """Set up schedulers."""
88
+ train_opt = self.opt['train']
89
+ scheduler_type = train_opt['scheduler'].pop('type')
90
+ if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
91
+ for optimizer in self.optimizers:
92
+ self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
93
+ elif scheduler_type == 'CosineAnnealingRestartLR':
94
+ for optimizer in self.optimizers:
95
+ self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
96
+ else:
97
+ raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
98
+
99
+ def get_bare_model(self, net):
100
+ """Get bare model, especially under wrapping with
101
+ DistributedDataParallel or DataParallel.
102
+ """
103
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
104
+ net = net.module
105
+ return net
106
+
107
+ @master_only
108
+ def print_network(self, net):
109
+ """Print the str and parameter number of a network.
110
+
111
+ Args:
112
+ net (nn.Module)
113
+ """
114
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
115
+ net_cls_str = (f'{net.__class__.__name__} - ' f'{net.module.__class__.__name__}')
116
+ else:
117
+ net_cls_str = f'{net.__class__.__name__}'
118
+
119
+ net = self.get_bare_model(net)
120
+ net_str = str(net)
121
+ net_params = sum(map(lambda x: x.numel(), net.parameters()))
122
+
123
+ logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
124
+ logger.info(net_str)
125
+
126
+ def _set_lr(self, lr_groups_l):
127
+ """Set learning rate for warmup.
128
+
129
+ Args:
130
+ lr_groups_l (list): List for lr_groups, each for an optimizer.
131
+ """
132
+ for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
133
+ for param_group, lr in zip(optimizer.param_groups, lr_groups):
134
+ param_group['lr'] = lr
135
+
136
+ def _get_init_lr(self):
137
+ """Get the initial lr, which is set by the scheduler.
138
+ """
139
+ init_lr_groups_l = []
140
+ for optimizer in self.optimizers:
141
+ init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
142
+ return init_lr_groups_l
143
+
144
+ def update_learning_rate(self, current_iter, warmup_iter=-1):
145
+ """Update learning rate.
146
+
147
+ Args:
148
+ current_iter (int): Current iteration.
149
+ warmup_iter (int): Warmup iter numbers. -1 for no warmup.
150
+ Default: -1.
151
+ """
152
+ if current_iter > 1:
153
+ for scheduler in self.schedulers:
154
+ scheduler.step()
155
+ # set up warm-up learning rate
156
+ if current_iter < warmup_iter:
157
+ # get initial lr for each group
158
+ init_lr_g_l = self._get_init_lr()
159
+ # modify warming-up learning rates
160
+ # currently only support linearly warm up
161
+ warm_up_lr_l = []
162
+ for init_lr_g in init_lr_g_l:
163
+ warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
164
+ # set learning rate
165
+ self._set_lr(warm_up_lr_l)
166
+
167
+ def get_current_learning_rate(self):
168
+ return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
169
+
170
+ @master_only
171
+ def save_network(self, net, net_label, current_iter, param_key='params'):
172
+ """Save networks.
173
+
174
+ Args:
175
+ net (nn.Module | list[nn.Module]): Network(s) to be saved.
176
+ net_label (str): Network label.
177
+ current_iter (int): Current iter number.
178
+ param_key (str | list[str]): The parameter key(s) to save network.
179
+ Default: 'params'.
180
+ """
181
+ if current_iter == -1:
182
+ current_iter = 'latest'
183
+ save_filename = f'{net_label}_{current_iter}.pth'
184
+ save_path = os.path.join(self.opt['path']['models'], save_filename)
185
+
186
+ net = net if isinstance(net, list) else [net]
187
+ param_key = param_key if isinstance(param_key, list) else [param_key]
188
+ assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'
189
+
190
+ save_dict = {}
191
+ for net_, param_key_ in zip(net, param_key):
192
+ net_ = self.get_bare_model(net_)
193
+ state_dict = net_.state_dict()
194
+ for key, param in state_dict.items():
195
+ if key.startswith('module.'): # remove unnecessary 'module.'
196
+ key = key[7:]
197
+ state_dict[key] = param.cpu()
198
+ save_dict[param_key_] = state_dict
199
+
200
+ torch.save(save_dict, save_path)
201
+
202
+ def _print_different_keys_loading(self, crt_net, load_net, strict=True):
203
+ """Print keys with differnet name or different size when loading models.
204
+
205
+ 1. Print keys with different names.
206
+ 2. If strict=False, print the same key but with different tensor size.
207
+ It also ignores keys with different sizes (they are not loaded).
208
+
209
+ Args:
210
+ crt_net (torch model): Current network.
211
+ load_net (dict): Loaded network.
212
+ strict (bool): Whether strictly loaded. Default: True.
213
+ """
214
+ crt_net = self.get_bare_model(crt_net)
215
+ crt_net = crt_net.state_dict()
216
+ crt_net_keys = set(crt_net.keys())
217
+ load_net_keys = set(load_net.keys())
218
+
219
+ if crt_net_keys != load_net_keys:
220
+ logger.warning('Current net - loaded net:')
221
+ for v in sorted(list(crt_net_keys - load_net_keys)):
222
+ logger.warning(f' {v}')
223
+ logger.warning('Loaded net - current net:')
224
+ for v in sorted(list(load_net_keys - crt_net_keys)):
225
+ logger.warning(f' {v}')
226
+
227
+ # check the size for the same keys
228
+ if not strict:
229
+ common_keys = crt_net_keys & load_net_keys
230
+ for k in common_keys:
231
+ if crt_net[k].size() != load_net[k].size():
232
+ logger.warning(f'Size different, ignore [{k}]: crt_net: '
233
+ f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
234
+ load_net[k + '.ignore'] = load_net.pop(k)
235
+
236
+ def load_network(self, net, load_path, strict=True, param_key='params'):
237
+ """Load network.
238
+
239
+ Args:
240
+ load_path (str): The path of networks to be loaded.
241
+ net (nn.Module): Network.
242
+ strict (bool): Whether strictly loaded.
243
+ param_key (str): The parameter key of loaded network. If set to
244
+ None, use the root 'path'.
245
+ Default: 'params'.
246
+ """
247
+ net = self.get_bare_model(net)
248
+ logger.info(f'Loading {net.__class__.__name__} model from {load_path}.')
249
+ load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
250
+ if param_key is not None:
251
+ if param_key not in load_net and 'params' in load_net:
252
+ param_key = 'params'
253
+ logger.info('Loading: params_ema does not exist, use params.')
254
+ load_net = load_net[param_key]
255
+ # remove unnecessary 'module.'
256
+ for k, v in deepcopy(load_net).items():
257
+ if k.startswith('module.'):
258
+ load_net[k[7:]] = v
259
+ load_net.pop(k)
260
+ self._print_different_keys_loading(net, load_net, strict)
261
+ net.load_state_dict(load_net, strict=strict)
262
+
263
+ @master_only
264
+ def save_training_state(self, epoch, current_iter):
265
+ """Save training states during training, which will be used for
266
+ resuming.
267
+
268
+ Args:
269
+ epoch (int): Current epoch.
270
+ current_iter (int): Current iteration.
271
+ """
272
+ if current_iter != -1:
273
+ state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []}
274
+ for o in self.optimizers:
275
+ state['optimizers'].append(o.state_dict())
276
+ for s in self.schedulers:
277
+ state['schedulers'].append(s.state_dict())
278
+ save_filename = f'{current_iter}.state'
279
+ save_path = os.path.join(self.opt['path']['training_states'], save_filename)
280
+ torch.save(state, save_path)
281
+
282
+ def resume_training(self, resume_state):
283
+ """Reload the optimizers and schedulers for resumed training.
284
+
285
+ Args:
286
+ resume_state (dict): Resume state.
287
+ """
288
+ resume_optimizers = resume_state['optimizers']
289
+ resume_schedulers = resume_state['schedulers']
290
+ assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
291
+ assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
292
+ for i, o in enumerate(resume_optimizers):
293
+ self.optimizers[i].load_state_dict(o)
294
+ for i, s in enumerate(resume_schedulers):
295
+ self.schedulers[i].load_state_dict(s)
296
+
297
+ def reduce_loss_dict(self, loss_dict):
298
+ """reduce loss dict.
299
+
300
+ In distributed training, it averages the losses among different GPUs .
301
+
302
+ Args:
303
+ loss_dict (OrderedDict): Loss dict.
304
+ """
305
+ with torch.no_grad():
306
+ if self.opt['dist']:
307
+ keys = []
308
+ losses = []
309
+ for name, value in loss_dict.items():
310
+ keys.append(name)
311
+ losses.append(value)
312
+ losses = torch.stack(losses, 0)
313
+ torch.distributed.reduce(losses, dst=0)
314
+ if self.opt['rank'] == 0:
315
+ losses /= self.opt['world_size']
316
+ loss_dict = {key: loss for key, loss in zip(keys, losses)}
317
+
318
+ log_dict = OrderedDict()
319
+ for name, value in loss_dict.items():
320
+ log_dict[name] = value.mean().item()
321
+
322
+ return log_dict
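# Illustrative sketch: model_ema() above performs a plain exponential moving
# average over parameters, ema <- decay * ema + (1 - decay) * param.
import torch
import torch.nn as nn

net = nn.Linear(4, 4)
net_ema = nn.Linear(4, 4)
net_ema.load_state_dict(net.state_dict())  # start the EMA copy from the live weights

decay = 0.999
with torch.no_grad():
    for p_ema, p in zip(net_ema.parameters(), net.parameters()):
        p_ema.data.mul_(decay).add_(p.data, alpha=1 - decay)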
blissful_tuner/codeformer/basicsr/models/codeformer_idx_model.py ADDED
@@ -0,0 +1,220 @@
1
+ import torch
2
+ from collections import OrderedDict
3
+ from os import path as osp
4
+ from tqdm import tqdm
5
+
6
+ from codeformer.basicsr.archs import build_network
7
+ from codeformer.basicsr.metrics import calculate_metric
8
+ from codeformer.basicsr.utils import get_root_logger, imwrite, tensor2img
9
+ from codeformer.basicsr.utils.registry import MODEL_REGISTRY
10
+ import torch.nn.functional as F
11
+ from .sr_model import SRModel
12
+
13
+
14
+ @MODEL_REGISTRY.register()
15
+ class CodeFormerIdxModel(SRModel):
16
+ def feed_data(self, data):
17
+ self.gt = data['gt'].to(self.device)
18
+ self.input = data['in'].to(self.device)
19
+ self.b = self.gt.shape[0]
20
+
21
+ if 'latent_gt' in data:
22
+ self.idx_gt = data['latent_gt'].to(self.device)
23
+ self.idx_gt = self.idx_gt.view(self.b, -1)
24
+ else:
25
+ self.idx_gt = None
26
+
27
+ def init_training_settings(self):
28
+ logger = get_root_logger()
29
+ train_opt = self.opt['train']
30
+
31
+ self.ema_decay = train_opt.get('ema_decay', 0)
32
+ if self.ema_decay > 0:
33
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
34
+ # define network net_g with Exponential Moving Average (EMA)
35
+ # net_g_ema is used only for testing on one GPU and saving
36
+ # There is no need to wrap with DistributedDataParallel
37
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
38
+ # load pretrained model
39
+ load_path = self.opt['path'].get('pretrain_network_g', None)
40
+ if load_path is not None:
41
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
42
+ else:
43
+ self.model_ema(0) # copy net_g weight
44
+ self.net_g_ema.eval()
45
+
46
+ if self.opt['datasets']['train'].get('latent_gt_path', None) is not None:
47
+ self.generate_idx_gt = False
48
+ elif self.opt.get('network_vqgan', None) is not None:
49
+ self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
50
+ self.hq_vqgan_fix.eval()
51
+ self.generate_idx_gt = True
52
+ for param in self.hq_vqgan_fix.parameters():
53
+ param.requires_grad = False
54
+ else:
55
+ raise NotImplementedError('Should have a network_vqgan config or a pre-calculated latent code.')
56
+
57
+ logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}')
58
+
59
+ self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
60
+ self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
61
+ self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
62
+ self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
63
+
64
+ self.net_g.train()
65
+
66
+ # set up optimizers and schedulers
67
+ self.setup_optimizers()
68
+ self.setup_schedulers()
69
+
70
+
71
+ def setup_optimizers(self):
72
+ train_opt = self.opt['train']
73
+ # optimizer g
74
+ optim_params_g = []
75
+ for k, v in self.net_g.named_parameters():
76
+ if v.requires_grad:
77
+ optim_params_g.append(v)
78
+ else:
79
+ logger = get_root_logger()
80
+ logger.warning(f'Params {k} will not be optimized.')
81
+ optim_type = train_opt['optim_g'].pop('type')
82
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
83
+ self.optimizers.append(self.optimizer_g)
84
+
85
+
86
+ def optimize_parameters(self, current_iter):
87
+ logger = get_root_logger()
88
+ # optimize net_g
89
+ self.optimizer_g.zero_grad()
90
+
91
+ if self.generate_idx_gt:
92
+ x = self.hq_vqgan_fix.encoder(self.gt)
93
+ _, _, quant_stats = self.hq_vqgan_fix.quantize(x)
94
+ min_encoding_indices = quant_stats['min_encoding_indices']
95
+ self.idx_gt = min_encoding_indices.view(self.b, -1)
96
+
97
+ if self.hq_feat_loss:
98
+ # quant_feats
99
+ quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256])
100
+
101
+ logits, lq_feat = self.net_g(self.input, w=0, code_only=True)
102
+
103
+ l_g_total = 0
104
+ loss_dict = OrderedDict()
105
+ # hq_feat_loss
106
+ if self.hq_feat_loss: # codebook loss
107
+ l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight
108
+ l_g_total += l_feat_encoder
109
+ loss_dict['l_feat_encoder'] = l_feat_encoder
110
+
111
+ # cross_entropy_loss
112
+ if self.cross_entropy_loss:
113
+ # b(hw)n -> bn(hw)
114
+ cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight
115
+ l_g_total += cross_entropy_loss
116
+ loss_dict['cross_entropy_loss'] = cross_entropy_loss
117
+
118
+ l_g_total.backward()
119
+ self.optimizer_g.step()
120
+
121
+ if self.ema_decay > 0:
122
+ self.model_ema(decay=self.ema_decay)
123
+
124
+ self.log_dict = self.reduce_loss_dict(loss_dict)
125
+
126
+
127
+ def test(self):
128
+ with torch.no_grad():
129
+ if hasattr(self, 'net_g_ema'):
130
+ self.net_g_ema.eval()
131
+ self.output, _, _ = self.net_g_ema(self.input, w=0)
132
+ else:
133
+ logger = get_root_logger()
134
+ logger.warning('Do not have self.net_g_ema, use self.net_g.')
135
+ self.net_g.eval()
136
+ self.output, _, _ = self.net_g(self.input, w=0)
137
+ self.net_g.train()
138
+
139
+
140
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
141
+ if self.opt['rank'] == 0:
142
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
143
+
144
+
145
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
146
+ dataset_name = dataloader.dataset.opt['name']
147
+ with_metrics = self.opt['val'].get('metrics') is not None
148
+ if with_metrics:
149
+ self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
150
+ pbar = tqdm(total=len(dataloader), unit='image')
151
+
152
+ for idx, val_data in enumerate(dataloader):
153
+ img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
154
+ self.feed_data(val_data)
155
+ self.test()
156
+
157
+ visuals = self.get_current_visuals()
158
+ sr_img = tensor2img([visuals['result']])
159
+ if 'gt' in visuals:
160
+ gt_img = tensor2img([visuals['gt']])
161
+ del self.gt
162
+
163
+ # tentative for out of GPU memory
164
+ del self.input  # feed_data stores the low-quality input as self.input, not self.lq
165
+ del self.output
166
+ torch.cuda.empty_cache()
167
+
168
+ if save_img:
169
+ if self.opt['is_train']:
170
+ save_img_path = osp.join(self.opt['path']['visualization'], img_name,
171
+ f'{img_name}_{current_iter}.png')
172
+ else:
173
+ if self.opt['val']['suffix']:
174
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
175
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
176
+ else:
177
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
178
+ f'{img_name}_{self.opt["name"]}.png')
179
+ imwrite(sr_img, save_img_path)
180
+
181
+ if with_metrics:
182
+ # calculate metrics
183
+ for name, opt_ in self.opt['val']['metrics'].items():
184
+ metric_data = dict(img1=sr_img, img2=gt_img)
185
+ self.metric_results[name] += calculate_metric(metric_data, opt_)
186
+ pbar.update(1)
187
+ pbar.set_description(f'Test {img_name}')
188
+ pbar.close()
189
+
190
+ if with_metrics:
191
+ for metric in self.metric_results.keys():
192
+ self.metric_results[metric] /= (idx + 1)
193
+
194
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
195
+
196
+
197
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
198
+ log_str = f'Validation {dataset_name}\n'
199
+ for metric, value in self.metric_results.items():
200
+ log_str += f'\t # {metric}: {value:.4f}\n'
201
+ logger = get_root_logger()
202
+ logger.info(log_str)
203
+ if tb_logger:
204
+ for metric, value in self.metric_results.items():
205
+ tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
206
+
207
+
208
+ def get_current_visuals(self):
209
+ out_dict = OrderedDict()
210
+ out_dict['gt'] = self.gt.detach().cpu()
211
+ out_dict['result'] = self.output.detach().cpu()
212
+ return out_dict
213
+
214
+
215
+ def save(self, epoch, current_iter):
216
+ if self.ema_decay > 0:
217
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
218
+ else:
219
+ self.save_network(self.net_g, 'net_g', current_iter)
220
+ self.save_training_state(epoch, current_iter)
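# Illustrative sketch: the "b(hw)n -> bn(hw)" permute used in
# optimize_parameters() moves the codebook dimension to where F.cross_entropy
# expects the class dimension.
import torch
import torch.nn.functional as F

b, hw, n_codes = 2, 256, 1024
logits = torch.randn(b, hw, n_codes)         # one distribution over codes per token
idx_gt = torch.randint(0, n_codes, (b, hw))  # ground-truth codebook indices
loss = F.cross_entropy(logits.permute(0, 2, 1), idx_gt)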
blissful_tuner/codeformer/basicsr/models/codeformer_joint_model.py ADDED
@@ -0,0 +1,350 @@
1
+ import torch
2
+ from collections import OrderedDict
3
+ from os import path as osp
4
+ from tqdm import tqdm
5
+
6
+
7
+ from codeformer.basicsr.archs import build_network
8
+ from codeformer.basicsr.losses import build_loss
9
+ from codeformer.basicsr.metrics import calculate_metric
10
+ from codeformer.basicsr.utils import get_root_logger, imwrite, tensor2img
11
+ from codeformer.basicsr.utils.registry import MODEL_REGISTRY
12
+ import torch.nn.functional as F
13
+ from .sr_model import SRModel
14
+
15
+
16
+ @MODEL_REGISTRY.register()
17
+ class CodeFormerJointModel(SRModel):
18
+ def feed_data(self, data):
19
+ self.gt = data['gt'].to(self.device)
20
+ self.input = data['in'].to(self.device)
21
+ self.input_large_de = data['in_large_de'].to(self.device)
22
+ self.b = self.gt.shape[0]
23
+
24
+ if 'latent_gt' in data:
25
+ self.idx_gt = data['latent_gt'].to(self.device)
26
+ self.idx_gt = self.idx_gt.view(self.b, -1)
27
+ else:
28
+ self.idx_gt = None
29
+
30
+ def init_training_settings(self):
31
+ logger = get_root_logger()
32
+ train_opt = self.opt['train']
33
+
34
+ self.ema_decay = train_opt.get('ema_decay', 0)
35
+ if self.ema_decay > 0:
36
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
37
+ # define network net_g with Exponential Moving Average (EMA)
38
+ # net_g_ema is used only for testing on one GPU and saving
39
+ # There is no need to wrap with DistributedDataParallel
40
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
41
+ # load pretrained model
42
+ load_path = self.opt['path'].get('pretrain_network_g', None)
43
+ if load_path is not None:
44
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
45
+ else:
46
+ self.model_ema(0) # copy net_g weight
47
+ self.net_g_ema.eval()
48
+
49
+ if self.opt['datasets']['train'].get('latent_gt_path', None) is not None:
50
+ self.generate_idx_gt = False
51
+ elif self.opt.get('network_vqgan', None) is not None:
52
+ self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
53
+ self.hq_vqgan_fix.eval()
54
+ self.generate_idx_gt = True
55
+ for param in self.hq_vqgan_fix.parameters():
56
+ param.requires_grad = False
57
+ else:
58
+ raise NotImplementedError('Should have a network_vqgan config or a pre-calculated latent code.')
59
+
60
+ logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}')
61
+
62
+ self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
63
+ self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
64
+ self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
65
+ self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
66
+ self.scale_adaptive_gan_weight = train_opt.get('scale_adaptive_gan_weight', 0.8)
67
+
68
+ # define network net_d
69
+ self.net_d = build_network(self.opt['network_d'])
70
+ self.net_d = self.model_to_device(self.net_d)
71
+ self.print_network(self.net_d)
72
+
73
+ # load pretrained models
74
+ load_path = self.opt['path'].get('pretrain_network_d', None)
75
+ if load_path is not None:
76
+ self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
77
+
78
+ self.net_g.train()
79
+ self.net_d.train()
80
+
81
+ # define losses
82
+ if train_opt.get('pixel_opt'):
83
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
84
+ else:
85
+ self.cri_pix = None
86
+
87
+ if train_opt.get('perceptual_opt'):
88
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
89
+ else:
90
+ self.cri_perceptual = None
91
+
92
+ if train_opt.get('gan_opt'):
93
+ self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
94
+
95
+
96
+ self.fix_generator = train_opt.get('fix_generator', True)
97
+ logger.info(f'fix_generator: {self.fix_generator}')
98
+
99
+ self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
100
+ self.net_d_iters = train_opt.get('net_d_iters', 1)
101
+ self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
102
+
103
+ # set up optimizers and schedulers
104
+ self.setup_optimizers()
105
+ self.setup_schedulers()
106
+
107
+ def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
108
+ recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
109
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
110
+
111
+ d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
112
+ d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
113
+ return d_weight
114
+
115
+ def setup_optimizers(self):
116
+ train_opt = self.opt['train']
117
+ # optimizer g
118
+ optim_params_g = []
119
+ for k, v in self.net_g.named_parameters():
120
+ if v.requires_grad:
121
+ optim_params_g.append(v)
122
+ else:
123
+ logger = get_root_logger()
124
+ logger.warning(f'Params {k} will not be optimized.')
125
+ optim_type = train_opt['optim_g'].pop('type')
126
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
127
+ self.optimizers.append(self.optimizer_g)
128
+ # optimizer d
129
+ optim_type = train_opt['optim_d'].pop('type')
130
+ self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
131
+ self.optimizers.append(self.optimizer_d)
132
+
133
+ def gray_resize_for_identity(self, out, size=128):
134
+ out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
135
+ out_gray = out_gray.unsqueeze(1)
136
+ out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
137
+ return out_gray
138
+
139
+ def optimize_parameters(self, current_iter):
140
+ logger = get_root_logger()
141
+ # optimize net_g
142
+ for p in self.net_d.parameters():
143
+ p.requires_grad = False
144
+
145
+ self.optimizer_g.zero_grad()
146
+
147
+ if self.generate_idx_gt:
148
+ x = self.hq_vqgan_fix.encoder(self.gt)
149
+ output, _, quant_stats = self.hq_vqgan_fix.quantize(x)
150
+ min_encoding_indices = quant_stats['min_encoding_indices']
151
+ self.idx_gt = min_encoding_indices.view(self.b, -1)
152
+
153
+ if current_iter <= 40000: # small degradation
154
+ small_per_n = 1
155
+ w = 1
156
+ elif current_iter <= 80000: # small degradation
157
+ small_per_n = 1
158
+ w = 1.3
159
+ elif current_iter <= 120000: # large degradation
160
+ small_per_n = 120000
161
+ w = 0
162
+ else: # mixed degradation
163
+ small_per_n = 15
164
+ w = 1.3
165
+
166
+ if current_iter % small_per_n == 0:
167
+ self.output, logits, lq_feat = self.net_g(self.input, w=w, detach_16=True)
168
+ large_de = False
169
+ else:
170
+ logits, lq_feat = self.net_g(self.input_large_de, code_only=True)
171
+ large_de = True
172
+
173
+ if self.hq_feat_loss:
174
+ # quant_feats
175
+ quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256])
176
+
177
+ l_g_total = 0
178
+ loss_dict = OrderedDict()
179
+ if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
180
+ # hq_feat_loss
181
+ if 'transformer' not in self.opt['network_g']['fix_modules']:
182
+ if self.hq_feat_loss: # codebook loss
183
+ l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight
184
+ l_g_total += l_feat_encoder
185
+ loss_dict['l_feat_encoder'] = l_feat_encoder
186
+
187
+ # cross_entropy_loss
188
+ if self.cross_entropy_loss:
189
+ # b(hw)n -> bn(hw)
190
+ cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight
191
+ l_g_total += cross_entropy_loss
192
+ loss_dict['cross_entropy_loss'] = cross_entropy_loss
193
+
194
+ # pixel loss
195
+ if not large_de: # when large degradation don't need image-level loss
196
+ if self.cri_pix:
197
+ l_g_pix = self.cri_pix(self.output, self.gt)
198
+ l_g_total += l_g_pix
199
+ loss_dict['l_g_pix'] = l_g_pix
200
+
201
+ # perceptual loss
202
+ if self.cri_perceptual:
203
+ l_g_percep = self.cri_perceptual(self.output, self.gt)
204
+ l_g_total += l_g_percep
205
+ loss_dict['l_g_percep'] = l_g_percep
206
+
207
+ # gan loss
208
+ if current_iter > self.net_d_start_iter:
209
+ fake_g_pred = self.net_d(self.output)
210
+ l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
211
+ recon_loss = l_g_pix + l_g_percep
212
+ if not self.fix_generator:
213
+ last_layer = self.net_g.module.generator.blocks[-1].weight
214
+ d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
215
+ else:
216
+ largest_fuse_size = self.opt['network_g']['connect_list'][-1]
217
+ last_layer = self.net_g.module.fuse_convs_dict[largest_fuse_size].shift[-1].weight
218
+ d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
219
+
220
+ d_weight *= self.scale_adaptive_gan_weight # 0.8
221
+ loss_dict['d_weight'] = d_weight
222
+ l_g_total += d_weight * l_g_gan
223
+ loss_dict['l_g_gan'] = d_weight * l_g_gan
224
+
225
+ l_g_total.backward()
226
+ self.optimizer_g.step()
227
+
228
+ if self.ema_decay > 0:
229
+ self.model_ema(decay=self.ema_decay)
230
+
231
+ # optimize net_d
232
+ if not large_de:
233
+ if current_iter > self.net_d_start_iter:
234
+ for p in self.net_d.parameters():
235
+ p.requires_grad = True
236
+
237
+ self.optimizer_d.zero_grad()
238
+ # real
239
+ real_d_pred = self.net_d(self.gt)
240
+ l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
241
+ loss_dict['l_d_real'] = l_d_real
242
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
243
+ l_d_real.backward()
244
+ # fake
245
+ fake_d_pred = self.net_d(self.output.detach())
246
+ l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
247
+ loss_dict['l_d_fake'] = l_d_fake
248
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
249
+ l_d_fake.backward()
250
+
251
+ self.optimizer_d.step()
252
+
253
+ self.log_dict = self.reduce_loss_dict(loss_dict)
254
+
255
+
256
+ def test(self):
257
+ with torch.no_grad():
258
+ if hasattr(self, 'net_g_ema'):
259
+ self.net_g_ema.eval()
260
+ self.output, _, _ = self.net_g_ema(self.input, w=1)
261
+ else:
262
+ logger = get_root_logger()
263
+ logger.warning('Do not have self.net_g_ema, use self.net_g.')
264
+ self.net_g.eval()
265
+ self.output, _, _ = self.net_g(self.input, w=1)
266
+ self.net_g.train()
267
+
268
+
269
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
270
+ if self.opt['rank'] == 0:
271
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
272
+
273
+
274
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
275
+ dataset_name = dataloader.dataset.opt['name']
276
+ with_metrics = self.opt['val'].get('metrics') is not None
277
+ if with_metrics:
278
+ self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
279
+ pbar = tqdm(total=len(dataloader), unit='image')
280
+
281
+ for idx, val_data in enumerate(dataloader):
282
+ img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
283
+ self.feed_data(val_data)
284
+ self.test()
285
+
286
+ visuals = self.get_current_visuals()
287
+ sr_img = tensor2img([visuals['result']])
288
+ if 'gt' in visuals:
289
+ gt_img = tensor2img([visuals['gt']])
290
+ del self.gt
291
+
292
+ # tentative for out of GPU memory
293
+ del self.lq
294
+ del self.output
295
+ torch.cuda.empty_cache()
296
+
297
+ if save_img:
298
+ if self.opt['is_train']:
299
+ save_img_path = osp.join(self.opt['path']['visualization'], img_name,
300
+ f'{img_name}_{current_iter}.png')
301
+ else:
302
+ if self.opt['val']['suffix']:
303
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
304
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
305
+ else:
306
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
307
+ f'{img_name}_{self.opt["name"]}.png')
308
+ imwrite(sr_img, save_img_path)
309
+
310
+ if with_metrics:
311
+ # calculate metrics
312
+ for name, opt_ in self.opt['val']['metrics'].items():
313
+ metric_data = dict(img1=sr_img, img2=gt_img)
314
+ self.metric_results[name] += calculate_metric(metric_data, opt_)
315
+ pbar.update(1)
316
+ pbar.set_description(f'Test {img_name}')
317
+ pbar.close()
318
+
319
+ if with_metrics:
320
+ for metric in self.metric_results.keys():
321
+ self.metric_results[metric] /= (idx + 1)
322
+
323
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
324
+
325
+
326
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
327
+ log_str = f'Validation {dataset_name}\n'
328
+ for metric, value in self.metric_results.items():
329
+ log_str += f'\t # {metric}: {value:.4f}\n'
330
+ logger = get_root_logger()
331
+ logger.info(log_str)
332
+ if tb_logger:
333
+ for metric, value in self.metric_results.items():
334
+ tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
335
+
336
+
337
+ def get_current_visuals(self):
338
+ out_dict = OrderedDict()
339
+ out_dict['gt'] = self.gt.detach().cpu()
340
+ out_dict['result'] = self.output.detach().cpu()
341
+ return out_dict
342
+
343
+
344
+ def save(self, epoch, current_iter):
345
+ if self.ema_decay > 0:
346
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
347
+ else:
348
+ self.save_network(self.net_g, 'net_g', current_iter)
349
+ self.save_network(self.net_d, 'net_d', current_iter)
350
+ self.save_training_state(epoch, current_iter)
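# Illustrative sketch: calculate_adaptive_weight() above balances the GAN term
# against the reconstruction term by comparing gradient norms at a chosen last
# layer, then clamping and detaching the ratio.
import torch
import torch.nn as nn

last_layer = nn.Linear(8, 8)
x = torch.randn(4, 8)
out = last_layer(x)
recon_loss = (out - torch.randn(4, 8)).pow(2).mean()
g_loss = -out.mean()

recon_grads = torch.autograd.grad(recon_loss, last_layer.weight, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer.weight, retain_graph=True)[0]
d_weight = (recon_grads.norm() / (g_grads.norm() + 1e-4)).clamp(0.0, 1.0).detach()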
blissful_tuner/codeformer/basicsr/models/codeformer_model.py ADDED
@@ -0,0 +1,332 @@
1
+ import torch
2
+ from collections import OrderedDict
3
+ from os import path as osp
4
+ from tqdm import tqdm
5
+
6
+ from codeformer.basicsr.archs import build_network
7
+ from codeformer.basicsr.losses import build_loss
8
+ from codeformer.basicsr.metrics import calculate_metric
9
+ from codeformer.basicsr.utils import get_root_logger, imwrite, tensor2img
10
+ from codeformer.basicsr.utils.registry import MODEL_REGISTRY
11
+ import torch.nn.functional as F
12
+ from .sr_model import SRModel
13
+
14
+
15
+ @MODEL_REGISTRY.register()
16
+ class CodeFormerModel(SRModel):
17
+ def feed_data(self, data):
18
+ self.gt = data['gt'].to(self.device)
19
+ self.input = data['in'].to(self.device)
20
+ self.b = self.gt.shape[0]
21
+
22
+ if 'latent_gt' in data:
23
+ self.idx_gt = data['latent_gt'].to(self.device)
24
+ self.idx_gt = self.idx_gt.view(self.b, -1)
25
+ else:
26
+ self.idx_gt = None
27
+
28
+ def init_training_settings(self):
29
+ logger = get_root_logger()
30
+ train_opt = self.opt['train']
31
+
32
+ self.ema_decay = train_opt.get('ema_decay', 0)
33
+ if self.ema_decay > 0:
34
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
35
+ # define network net_g with Exponential Moving Average (EMA)
36
+ # net_g_ema is used only for testing on one GPU and saving
37
+ # There is no need to wrap with DistributedDataParallel
38
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
39
+ # load pretrained model
40
+ load_path = self.opt['path'].get('pretrain_network_g', None)
41
+ if load_path is not None:
42
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
43
+ else:
44
+ self.model_ema(0) # copy net_g weight
45
+ self.net_g_ema.eval()
46
+
47
+ if self.opt.get('network_vqgan', None) is not None and self.opt['datasets'].get('latent_gt_path') is None:
48
+ self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
49
+ self.hq_vqgan_fix.eval()
50
+ self.generate_idx_gt = True
51
+ for param in self.hq_vqgan_fix.parameters():
52
+ param.requires_grad = False
53
+ else:
54
+ self.generate_idx_gt = False
55
+
56
+ self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
57
+ self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
58
+ self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
59
+ self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
60
+ self.fidelity_weight = train_opt.get('fidelity_weight', 1.0)
61
+ self.scale_adaptive_gan_weight = train_opt.get('scale_adaptive_gan_weight', 0.8)
62
+
63
+
64
+ self.net_g.train()
65
+ # define network net_d
66
+ if self.fidelity_weight > 0:
67
+ self.net_d = build_network(self.opt['network_d'])
68
+ self.net_d = self.model_to_device(self.net_d)
69
+ self.print_network(self.net_d)
70
+
71
+ # load pretrained models
72
+ load_path = self.opt['path'].get('pretrain_network_d', None)
73
+ if load_path is not None:
74
+ self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
75
+
76
+ self.net_d.train()
77
+
78
+ # define losses
79
+ if train_opt.get('pixel_opt'):
80
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
81
+ else:
82
+ self.cri_pix = None
83
+
84
+ if train_opt.get('perceptual_opt'):
85
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
86
+ else:
87
+ self.cri_perceptual = None
88
+
89
+ if train_opt.get('gan_opt'):
90
+ self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
91
+
92
+
93
+ self.fix_generator = train_opt.get('fix_generator', True)
94
+ logger.info(f'fix_generator: {self.fix_generator}')
95
+
96
+ self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
97
+ self.net_d_iters = train_opt.get('net_d_iters', 1)
98
+ self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
99
+
100
+ # set up optimizers and schedulers
101
+ self.setup_optimizers()
102
+ self.setup_schedulers()
103
+
104
+ def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
105
+ recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
106
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
107
+
108
+ d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
109
+ d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
110
+ return d_weight
111
+
112
+ def setup_optimizers(self):
113
+ train_opt = self.opt['train']
114
+ # optimizer g
115
+ optim_params_g = []
116
+ for k, v in self.net_g.named_parameters():
117
+ if v.requires_grad:
118
+ optim_params_g.append(v)
119
+ else:
120
+ logger = get_root_logger()
121
+ logger.warning(f'Params {k} will not be optimized.')
122
+ optim_type = train_opt['optim_g'].pop('type')
123
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
124
+ self.optimizers.append(self.optimizer_g)
125
+ # optimizer d
126
+ if self.fidelity_weight > 0:
127
+ optim_type = train_opt['optim_d'].pop('type')
128
+ self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
129
+ self.optimizers.append(self.optimizer_d)
130
+
131
+ def gray_resize_for_identity(self, out, size=128):
132
+ out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
133
+ out_gray = out_gray.unsqueeze(1)
134
+ out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
135
+ return out_gray
136
+
137
+ def optimize_parameters(self, current_iter):
138
+ logger = get_root_logger()
139
+ # optimize net_g
140
+ for p in self.net_d.parameters():
141
+ p.requires_grad = False
142
+
143
+ self.optimizer_g.zero_grad()
144
+
145
+ if self.generate_idx_gt:
146
+ x = self.hq_vqgan_fix.encoder(self.gt)
147
+ output, _, quant_stats = self.hq_vqgan_fix.quantize(x)
148
+ min_encoding_indices = quant_stats['min_encoding_indices']
149
+ self.idx_gt = min_encoding_indices.view(self.b, -1)
150
+
151
+ if self.fidelity_weight > 0:
152
+ self.output, logits, lq_feat = self.net_g(self.input, w=self.fidelity_weight, detach_16=True)
153
+ else:
154
+ logits, lq_feat = self.net_g(self.input, w=0, code_only=True)
155
+
156
+ if self.hq_feat_loss:
157
+ # quant_feats
158
+ quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256])
159
+
160
+ l_g_total = 0
161
+ loss_dict = OrderedDict()
162
+ if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
163
+ # hq_feat_loss
164
+ if self.hq_feat_loss: # codebook loss
165
+ l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight
166
+ l_g_total += l_feat_encoder
167
+ loss_dict['l_feat_encoder'] = l_feat_encoder
168
+
169
+ # cross_entropy_loss
170
+ if self.cross_entropy_loss:
171
+ # b(hw)n -> bn(hw)
172
+ cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight
173
+ l_g_total += cross_entropy_loss
174
+ loss_dict['cross_entropy_loss'] = cross_entropy_loss
175
+
176
+ if self.fidelity_weight > 0: # when fidelity_weight == 0 don't need image-level loss
177
+ # pixel loss
178
+ if self.cri_pix:
179
+ l_g_pix = self.cri_pix(self.output, self.gt)
180
+ l_g_total += l_g_pix
181
+ loss_dict['l_g_pix'] = l_g_pix
182
+
183
+ # perceptual loss
184
+ if self.cri_perceptual:
185
+ l_g_percep = self.cri_perceptual(self.output, self.gt)
186
+ l_g_total += l_g_percep
187
+ loss_dict['l_g_percep'] = l_g_percep
188
+
189
+ # gan loss
190
+ if current_iter > self.net_d_start_iter:
191
+ fake_g_pred = self.net_d(self.output)
192
+ l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
193
+ recon_loss = l_g_pix + l_g_percep
194
+ if not self.fix_generator:
195
+ last_layer = self.net_g.module.generator.blocks[-1].weight
196
+ d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
197
+ else:
198
+ largest_fuse_size = self.opt['network_g']['connect_list'][-1]
199
+ last_layer = self.net_g.module.fuse_convs_dict[largest_fuse_size].shift[-1].weight
200
+ d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
201
+
202
+ d_weight *= self.scale_adaptive_gan_weight # 0.8
203
+ loss_dict['d_weight'] = d_weight
204
+ l_g_total += d_weight * l_g_gan
205
+ loss_dict['l_g_gan'] = d_weight * l_g_gan
206
+
207
+ l_g_total.backward()
208
+ self.optimizer_g.step()
209
+
210
+ if self.ema_decay > 0:
211
+ self.model_ema(decay=self.ema_decay)
212
+
213
+ # optimize net_d
214
+ if current_iter > self.net_d_start_iter and self.fidelity_weight > 0:
215
+ for p in self.net_d.parameters():
216
+ p.requires_grad = True
217
+
218
+ self.optimizer_d.zero_grad()
219
+ # real
220
+ real_d_pred = self.net_d(self.gt)
221
+ l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
222
+ loss_dict['l_d_real'] = l_d_real
223
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
224
+ l_d_real.backward()
225
+ # fake
226
+ fake_d_pred = self.net_d(self.output.detach())
227
+ l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
228
+ loss_dict['l_d_fake'] = l_d_fake
229
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
230
+ l_d_fake.backward()
231
+
232
+ self.optimizer_d.step()
233
+
234
+ self.log_dict = self.reduce_loss_dict(loss_dict)
235
+
236
+
237
+ def test(self):
238
+ with torch.no_grad():
239
+ if hasattr(self, 'net_g_ema'):
240
+ self.net_g_ema.eval()
241
+ self.output, _, _ = self.net_g_ema(self.input, w=self.fidelity_weight)
242
+ else:
243
+ logger = get_root_logger()
244
+ logger.warning('Do not have self.net_g_ema, use self.net_g.')
245
+ self.net_g.eval()
246
+ self.output, _, _ = self.net_g(self.input, w=self.fidelity_weight)
247
+ self.net_g.train()
248
+
249
+
250
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
251
+ if self.opt['rank'] == 0:
252
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
253
+
254
+
255
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
256
+ dataset_name = dataloader.dataset.opt['name']
257
+ with_metrics = self.opt['val'].get('metrics') is not None
258
+ if with_metrics:
259
+ self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
260
+ pbar = tqdm(total=len(dataloader), unit='image')
261
+
262
+ for idx, val_data in enumerate(dataloader):
263
+ img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
264
+ self.feed_data(val_data)
265
+ self.test()
266
+
267
+ visuals = self.get_current_visuals()
268
+ sr_img = tensor2img([visuals['result']])
269
+ if 'gt' in visuals:
270
+ gt_img = tensor2img([visuals['gt']])
271
+ del self.gt
272
+
273
+ # tentative for out of GPU memory
274
+ del self.input  # CodeFormerModel.feed_data stores the degraded face as self.input, not self.lq
275
+ del self.output
276
+ torch.cuda.empty_cache()
277
+
278
+ if save_img:
279
+ if self.opt['is_train']:
280
+ save_img_path = osp.join(self.opt['path']['visualization'], img_name,
281
+ f'{img_name}_{current_iter}.png')
282
+ else:
283
+ if self.opt['val']['suffix']:
284
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
285
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
286
+ else:
287
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
288
+ f'{img_name}_{self.opt["name"]}.png')
289
+ imwrite(sr_img, save_img_path)
290
+
291
+ if with_metrics:
292
+ # calculate metrics
293
+ for name, opt_ in self.opt['val']['metrics'].items():
294
+ metric_data = dict(img1=sr_img, img2=gt_img)
295
+ self.metric_results[name] += calculate_metric(metric_data, opt_)
296
+ pbar.update(1)
297
+ pbar.set_description(f'Test {img_name}')
298
+ pbar.close()
299
+
300
+ if with_metrics:
301
+ for metric in self.metric_results.keys():
302
+ self.metric_results[metric] /= (idx + 1)
303
+
304
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
305
+
306
+
307
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
308
+ log_str = f'Validation {dataset_name}\n'
309
+ for metric, value in self.metric_results.items():
310
+ log_str += f'\t # {metric}: {value:.4f}\n'
311
+ logger = get_root_logger()
312
+ logger.info(log_str)
313
+ if tb_logger:
314
+ for metric, value in self.metric_results.items():
315
+ tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
316
+
317
+
318
+ def get_current_visuals(self):
319
+ out_dict = OrderedDict()
320
+ out_dict['gt'] = self.gt.detach().cpu()
321
+ out_dict['result'] = self.output.detach().cpu()
322
+ return out_dict
323
+
324
+
325
+ def save(self, epoch, current_iter):
326
+ if self.ema_decay > 0:
327
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
328
+ else:
329
+ self.save_network(self.net_g, 'net_g', current_iter)
330
+ if self.fidelity_weight > 0:
331
+ self.save_network(self.net_d, 'net_d', current_iter)
332
+ self.save_training_state(epoch, current_iter)
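The `calculate_adaptive_weight` rule above balances the GAN term against the reconstruction term via the ratio of their gradient norms at a chosen last layer. A self-contained sketch of that computation on a toy layer (the tensors below are placeholders, not part of the training pipeline):

```python
# Toy illustration of the adaptive discriminator weight used above:
# ||grad(recon)|| / (||grad(gan)|| + 1e-4), clamped to [0, disc_weight_max] and detached.
import torch

layer = torch.nn.Linear(8, 8)
x = torch.randn(4, 8)
recon_loss = (layer(x) - torch.randn(4, 8)).pow(2).mean()  # stands in for pixel + perceptual loss
g_loss = layer(x).mean()                                   # stands in for the generator GAN loss

recon_grads = torch.autograd.grad(recon_loss, layer.weight, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, layer.weight, retain_graph=True)[0]
d_weight = (recon_grads.norm() / (g_grads.norm() + 1e-4)).clamp(0.0, 1.0).detach()
print(float(d_weight))  # scalar weight applied to l_g_gan before adding it to l_g_total
```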
blissful_tuner/codeformer/basicsr/models/lr_scheduler.py ADDED
@@ -0,0 +1,96 @@
1
+ import math
2
+ from collections import Counter
3
+ from torch.optim.lr_scheduler import _LRScheduler
4
+
5
+
6
+ class MultiStepRestartLR(_LRScheduler):
7
+ """ MultiStep with restarts learning rate scheme.
8
+
9
+ Args:
10
+ optimizer (torch.nn.optimizer): Torch optimizer.
11
+ milestones (list): Iterations that will decrease learning rate.
12
+ gamma (float): Decrease ratio. Default: 0.1.
13
+ restarts (list): Restart iterations. Default: [0].
14
+ restart_weights (list): Restart weights at each restart iteration.
15
+ Default: [1].
16
+ last_epoch (int): Used in _LRScheduler. Default: -1.
17
+ """
18
+
19
+ def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
20
+ self.milestones = Counter(milestones)
21
+ self.gamma = gamma
22
+ self.restarts = restarts
23
+ self.restart_weights = restart_weights
24
+ assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
25
+ super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
26
+
27
+ def get_lr(self):
28
+ if self.last_epoch in self.restarts:
29
+ weight = self.restart_weights[self.restarts.index(self.last_epoch)]
30
+ return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
31
+ if self.last_epoch not in self.milestones:
32
+ return [group['lr'] for group in self.optimizer.param_groups]
33
+ return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups]
34
+
35
+
36
+ def get_position_from_periods(iteration, cumulative_period):
37
+ """Get the position from a period list.
38
+
39
+ It will return the index of the right-closest number in the period list.
40
+ For example, the cumulative_period = [100, 200, 300, 400],
41
+ if iteration == 50, return 0;
42
+ if iteration == 210, return 2;
43
+ if iteration == 300, return 2.
44
+
45
+ Args:
46
+ iteration (int): Current iteration.
47
+ cumulative_period (list[int]): Cumulative period list.
48
+
49
+ Returns:
50
+ int: The position of the right-closest number in the period list.
51
+ """
52
+ for i, period in enumerate(cumulative_period):
53
+ if iteration <= period:
54
+ return i
55
+
56
+
57
+ class CosineAnnealingRestartLR(_LRScheduler):
58
+ """ Cosine annealing with restarts learning rate scheme.
59
+
60
+ An example of config:
61
+ periods = [10, 10, 10, 10]
62
+ restart_weights = [1, 0.5, 0.5, 0.5]
63
+ eta_min=1e-7
64
+
65
+ It has four cycles, each with 10 iterations. At the 10th, 20th and 30th iterations, the
66
+ scheduler will restart with the weights in restart_weights.
67
+
68
+ Args:
69
+ optimizer (torch.nn.optimizer): Torch optimizer.
70
+ periods (list): Period for each cosine annealing cycle.
71
+ restart_weights (list): Restart weights at each restart iteration.
72
+ Default: [1].
73
+ eta_min (float): The minimum lr. Default: 0.
74
+ last_epoch (int): Used in _LRScheduler. Default: -1.
75
+ """
76
+
77
+ def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
78
+ self.periods = periods
79
+ self.restart_weights = restart_weights
80
+ self.eta_min = eta_min
81
+ assert (len(self.periods) == len(
82
+ self.restart_weights)), 'periods and restart_weights should have the same length.'
83
+ self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
84
+ super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
85
+
86
+ def get_lr(self):
87
+ idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
88
+ current_weight = self.restart_weights[idx]
89
+ nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
90
+ current_period = self.periods[idx]
91
+
92
+ return [
93
+ self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
94
+ (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
95
+ for base_lr in self.base_lrs
96
+ ]
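As a quick sanity check of the schedule described in the `CosineAnnealingRestartLR` docstring, a hedged usage sketch; the import path is assumed from the package layout used by the other modules, and the parameter/optimizer are placeholders:

```python
# Minimal sketch: four 10-iteration cosine cycles with restart weights [1, 0.5, 0.5, 0.5],
# mirroring the example config in the docstring above.
import torch
from codeformer.basicsr.models.lr_scheduler import CosineAnnealingRestartLR  # assumed import path

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=2e-4)
scheduler = CosineAnnealingRestartLR(
    optimizer, periods=[10, 10, 10, 10], restart_weights=[1, 0.5, 0.5, 0.5], eta_min=1e-7)

for it in range(40):
    optimizer.step()
    scheduler.step()
    if it % 10 == 0:
        print(it, optimizer.param_groups[0]['lr'])  # lr restarts (rescaled) every 10 iterations
```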
blissful_tuner/codeformer/basicsr/models/sr_model.py ADDED
@@ -0,0 +1,209 @@
1
+ import torch
2
+ from collections import OrderedDict
3
+ from os import path as osp
4
+ from tqdm import tqdm
5
+
6
+ from codeformer.basicsr.archs import build_network
7
+ from codeformer.basicsr.losses import build_loss
8
+ from codeformer.basicsr.metrics import calculate_metric
9
+ from codeformer.basicsr.utils import get_root_logger, imwrite, tensor2img
10
+ from codeformer.basicsr.utils.registry import MODEL_REGISTRY
11
+ from .base_model import BaseModel
12
+
13
+ @MODEL_REGISTRY.register()
14
+ class SRModel(BaseModel):
15
+ """Base SR model for single image super-resolution."""
16
+
17
+ def __init__(self, opt):
18
+ super(SRModel, self).__init__(opt)
19
+
20
+ # define network
21
+ self.net_g = build_network(opt['network_g'])
22
+ self.net_g = self.model_to_device(self.net_g)
23
+ self.print_network(self.net_g)
24
+
25
+ # load pretrained models
26
+ load_path = self.opt['path'].get('pretrain_network_g', None)
27
+ if load_path is not None:
28
+ param_key = self.opt['path'].get('param_key_g', 'params')
29
+ self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
30
+
31
+ if self.is_train:
32
+ self.init_training_settings()
33
+
34
+ def init_training_settings(self):
35
+ self.net_g.train()
36
+ train_opt = self.opt['train']
37
+
38
+ self.ema_decay = train_opt.get('ema_decay', 0)
39
+ if self.ema_decay > 0:
40
+ logger = get_root_logger()
41
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
42
+ # define network net_g with Exponential Moving Average (EMA)
43
+ # net_g_ema is used only for testing on one GPU and saving
44
+ # There is no need to wrap with DistributedDataParallel
45
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
46
+ # load pretrained model
47
+ load_path = self.opt['path'].get('pretrain_network_g', None)
48
+ if load_path is not None:
49
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
50
+ else:
51
+ self.model_ema(0) # copy net_g weight
52
+ self.net_g_ema.eval()
53
+
54
+ # define losses
55
+ if train_opt.get('pixel_opt'):
56
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
57
+ else:
58
+ self.cri_pix = None
59
+
60
+ if train_opt.get('perceptual_opt'):
61
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
62
+ else:
63
+ self.cri_perceptual = None
64
+
65
+ if self.cri_pix is None and self.cri_perceptual is None:
66
+ raise ValueError('Both pixel and perceptual losses are None.')
67
+
68
+ # set up optimizers and schedulers
69
+ self.setup_optimizers()
70
+ self.setup_schedulers()
71
+
72
+ def setup_optimizers(self):
73
+ train_opt = self.opt['train']
74
+ optim_params = []
75
+ for k, v in self.net_g.named_parameters():
76
+ if v.requires_grad:
77
+ optim_params.append(v)
78
+ else:
79
+ logger = get_root_logger()
80
+ logger.warning(f'Params {k} will not be optimized.')
81
+
82
+ optim_type = train_opt['optim_g'].pop('type')
83
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
84
+ self.optimizers.append(self.optimizer_g)
85
+
86
+ def feed_data(self, data):
87
+ self.lq = data['lq'].to(self.device)
88
+ if 'gt' in data:
89
+ self.gt = data['gt'].to(self.device)
90
+
91
+ def optimize_parameters(self, current_iter):
92
+ self.optimizer_g.zero_grad()
93
+ self.output = self.net_g(self.lq)
94
+
95
+ l_total = 0
96
+ loss_dict = OrderedDict()
97
+ # pixel loss
98
+ if self.cri_pix:
99
+ l_pix = self.cri_pix(self.output, self.gt)
100
+ l_total += l_pix
101
+ loss_dict['l_pix'] = l_pix
102
+ # perceptual loss
103
+ if self.cri_perceptual:
104
+ l_percep, l_style = self.cri_perceptual(self.output, self.gt)
105
+ if l_percep is not None:
106
+ l_total += l_percep
107
+ loss_dict['l_percep'] = l_percep
108
+ if l_style is not None:
109
+ l_total += l_style
110
+ loss_dict['l_style'] = l_style
111
+
112
+ l_total.backward()
113
+ self.optimizer_g.step()
114
+
115
+ self.log_dict = self.reduce_loss_dict(loss_dict)
116
+
117
+ if self.ema_decay > 0:
118
+ self.model_ema(decay=self.ema_decay)
119
+
120
+ def test(self):
121
+ if hasattr(self, 'net_g_ema'):  # the EMA copy exists only when ema_decay > 0
122
+ self.net_g_ema.eval()
123
+ with torch.no_grad():
124
+ self.output = self.net_g_ema(self.lq)
125
+ else:
126
+ self.net_g.eval()
127
+ with torch.no_grad():
128
+ self.output = self.net_g(self.lq)
129
+ self.net_g.train()
130
+
131
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
132
+ if self.opt['rank'] == 0:
133
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
134
+
135
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
136
+ dataset_name = dataloader.dataset.opt['name']
137
+ with_metrics = self.opt['val'].get('metrics') is not None
138
+ if with_metrics:
139
+ self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
140
+ pbar = tqdm(total=len(dataloader), unit='image')
141
+
142
+ for idx, val_data in enumerate(dataloader):
143
+ img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
144
+ self.feed_data(val_data)
145
+ self.test()
146
+
147
+ visuals = self.get_current_visuals()
148
+ sr_img = tensor2img([visuals['result']])
149
+ if 'gt' in visuals:
150
+ gt_img = tensor2img([visuals['gt']])
151
+ del self.gt
152
+
153
+ # tentative for out of GPU memory
154
+ del self.lq
155
+ del self.output
156
+ torch.cuda.empty_cache()
157
+
158
+ if save_img:
159
+ if self.opt['is_train']:
160
+ save_img_path = osp.join(self.opt['path']['visualization'], img_name,
161
+ f'{img_name}_{current_iter}.png')
162
+ else:
163
+ if self.opt['val']['suffix']:
164
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
165
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
166
+ else:
167
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
168
+ f'{img_name}_{self.opt["name"]}.png')
169
+ imwrite(sr_img, save_img_path)
170
+
171
+ if with_metrics:
172
+ # calculate metrics
173
+ for name, opt_ in self.opt['val']['metrics'].items():
174
+ metric_data = dict(img1=sr_img, img2=gt_img)
175
+ self.metric_results[name] += calculate_metric(metric_data, opt_)
176
+ pbar.update(1)
177
+ pbar.set_description(f'Test {img_name}')
178
+ pbar.close()
179
+
180
+ if with_metrics:
181
+ for metric in self.metric_results.keys():
182
+ self.metric_results[metric] /= (idx + 1)
183
+
184
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
185
+
186
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
187
+ log_str = f'Validation {dataset_name}\n'
188
+ for metric, value in self.metric_results.items():
189
+ log_str += f'\t # {metric}: {value:.4f}\n'
190
+ logger = get_root_logger()
191
+ logger.info(log_str)
192
+ if tb_logger:
193
+ for metric, value in self.metric_results.items():
194
+ tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
195
+
196
+ def get_current_visuals(self):
197
+ out_dict = OrderedDict()
198
+ out_dict['lq'] = self.lq.detach().cpu()
199
+ out_dict['result'] = self.output.detach().cpu()
200
+ if hasattr(self, 'gt'):
201
+ out_dict['gt'] = self.gt.detach().cpu()
202
+ return out_dict
203
+
204
+ def save(self, epoch, current_iter):
205
+ if hasattr(self, 'net_g_ema'):
206
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
207
+ else:
208
+ self.save_network(self.net_g, 'net_g', current_iter)
209
+ self.save_training_state(epoch, current_iter)
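The `@MODEL_REGISTRY.register()` decorator above makes `SRModel` discoverable by its class name. A hedged sketch of how such a registry is typically queried, assuming the vendored `registry` module keeps upstream basicsr's `get()` API:

```python
# Sketch only: assumes codeformer.basicsr.utils.registry follows upstream basicsr,
# where register() stores the class under its own name.
from codeformer.basicsr.utils.registry import MODEL_REGISTRY
import codeformer.basicsr.models.sr_model  # noqa: F401  importing the module runs the decorator

model_cls = MODEL_REGISTRY.get('SRModel')
print(model_cls.__name__)  # 'SRModel'
```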
blissful_tuner/codeformer/basicsr/models/vqgan_model.py ADDED
@@ -0,0 +1,285 @@
1
+ import torch
2
+ from collections import OrderedDict
3
+ from os import path as osp
4
+ from tqdm import tqdm
5
+
6
+ from codeformer.basicsr.archs import build_network
7
+ from codeformer.basicsr.losses import build_loss
8
+ from codeformer.basicsr.metrics import calculate_metric
9
+ from codeformer.basicsr.utils import get_root_logger, imwrite, tensor2img
10
+ from codeformer.basicsr.utils.registry import MODEL_REGISTRY
11
+ import torch.nn.functional as F
12
+ from .sr_model import SRModel
13
+
14
+
15
+ @MODEL_REGISTRY.register()
16
+ class VQGANModel(SRModel):
17
+ def feed_data(self, data):
18
+ self.gt = data['gt'].to(self.device)
19
+ self.b = self.gt.shape[0]
20
+
21
+
22
+ def init_training_settings(self):
23
+ logger = get_root_logger()
24
+ train_opt = self.opt['train']
25
+
26
+ self.ema_decay = train_opt.get('ema_decay', 0)
27
+ if self.ema_decay > 0:
28
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
29
+ # define network net_g with Exponential Moving Average (EMA)
30
+ # net_g_ema is used only for testing on one GPU and saving
31
+ # There is no need to wrap with DistributedDataParallel
32
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
33
+ # load pretrained model
34
+ load_path = self.opt['path'].get('pretrain_network_g', None)
35
+ if load_path is not None:
36
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
37
+ else:
38
+ self.model_ema(0) # copy net_g weight
39
+ self.net_g_ema.eval()
40
+
41
+ # define network net_d
42
+ self.net_d = build_network(self.opt['network_d'])
43
+ self.net_d = self.model_to_device(self.net_d)
44
+ self.print_network(self.net_d)
45
+
46
+ # load pretrained models
47
+ load_path = self.opt['path'].get('pretrain_network_d', None)
48
+ if load_path is not None:
49
+ self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
50
+
51
+ self.net_g.train()
52
+ self.net_d.train()
53
+
54
+ # define losses
55
+ if train_opt.get('pixel_opt'):
56
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
57
+ else:
58
+ self.cri_pix = None
59
+
60
+ if train_opt.get('perceptual_opt'):
61
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
62
+ else:
63
+ self.cri_perceptual = None
64
+
65
+ if train_opt.get('gan_opt'):
66
+ self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
67
+
68
+ if train_opt.get('codebook_opt'):
69
+ self.l_weight_codebook = train_opt['codebook_opt'].get('loss_weight', 1.0)
70
+ else:
71
+ self.l_weight_codebook = 1.0
72
+
73
+ self.vqgan_quantizer = self.opt['network_g']['quantizer']
74
+ logger.info(f'vqgan_quantizer: {self.vqgan_quantizer}')
75
+
76
+ self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
77
+ self.net_d_iters = train_opt.get('net_d_iters', 1)
78
+ self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
79
+ self.disc_weight = train_opt.get('disc_weight', 0.8)
80
+
81
+ # set up optimizers and schedulers
82
+ self.setup_optimizers()
83
+ self.setup_schedulers()
84
+
85
+ def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
86
+ recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
87
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
88
+
89
+ d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
90
+ d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
91
+ return d_weight
92
+
93
+ def adopt_weight(self, weight, global_step, threshold=0, value=0.):
94
+ if global_step < threshold:
95
+ weight = value
96
+ return weight
97
+
98
+ def setup_optimizers(self):
99
+ train_opt = self.opt['train']
100
+ # optimizer g
101
+ optim_params_g = []
102
+ for k, v in self.net_g.named_parameters():
103
+ if v.requires_grad:
104
+ optim_params_g.append(v)
105
+ else:
106
+ logger = get_root_logger()
107
+ logger.warning(f'Params {k} will not be optimized.')
108
+ optim_type = train_opt['optim_g'].pop('type')
109
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
110
+ self.optimizers.append(self.optimizer_g)
111
+ # optimizer d
112
+ optim_type = train_opt['optim_d'].pop('type')
113
+ self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
114
+ self.optimizers.append(self.optimizer_d)
115
+
116
+
117
+ def optimize_parameters(self, current_iter):
118
+ logger = get_root_logger()
119
+ loss_dict = OrderedDict()
120
+ if self.opt['network_g']['quantizer'] == 'gumbel':
121
+ self.net_g.module.quantize.temperature = max(1/16, ((-1/160000) * current_iter) + 1)
122
+ if current_iter%1000 == 0:
123
+ logger.info(f'temperature: {self.net_g.module.quantize.temperature}')
124
+
125
+ # optimize net_g
126
+ for p in self.net_d.parameters():
127
+ p.requires_grad = False
128
+
129
+ self.optimizer_g.zero_grad()
130
+ self.output, l_codebook, quant_stats = self.net_g(self.gt)
131
+
132
+ l_codebook = l_codebook*self.l_weight_codebook
133
+
134
+ l_g_total = 0
135
+ if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
136
+ # pixel loss
137
+ if self.cri_pix:
138
+ l_g_pix = self.cri_pix(self.output, self.gt)
139
+ l_g_total += l_g_pix
140
+ loss_dict['l_g_pix'] = l_g_pix
141
+ # perceptual loss
142
+ if self.cri_perceptual:
143
+ l_g_percep = self.cri_perceptual(self.output, self.gt)
144
+ l_g_total += l_g_percep
145
+ loss_dict['l_g_percep'] = l_g_percep
146
+
147
+ # gan loss
148
+ if current_iter > self.net_d_start_iter:
149
+ # fake_g_pred = self.net_d(self.output_1024)
150
+ fake_g_pred = self.net_d(self.output)
151
+ l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
152
+ recon_loss = l_g_total
153
+ last_layer = self.net_g.module.generator.blocks[-1].weight
154
+ d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
155
+ d_weight *= self.adopt_weight(1, current_iter, self.net_d_start_iter)
156
+ d_weight *= self.disc_weight # taming-transformers setting 0.8
157
+ l_g_total += d_weight * l_g_gan
158
+ loss_dict['l_g_gan'] = d_weight * l_g_gan
159
+
160
+ l_g_total += l_codebook
161
+ loss_dict['l_codebook'] = l_codebook
162
+
163
+ l_g_total.backward()
164
+ self.optimizer_g.step()
165
+
166
+ # optimize net_d
167
+ if current_iter > self.net_d_start_iter:
168
+ for p in self.net_d.parameters():
169
+ p.requires_grad = True
170
+
171
+ self.optimizer_d.zero_grad()
172
+ # real
173
+ real_d_pred = self.net_d(self.gt)
174
+ l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
175
+ loss_dict['l_d_real'] = l_d_real
176
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
177
+ l_d_real.backward()
178
+ # fake
179
+ fake_d_pred = self.net_d(self.output.detach())
180
+ l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
181
+ loss_dict['l_d_fake'] = l_d_fake
182
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
183
+ l_d_fake.backward()
184
+ self.optimizer_d.step()
185
+
186
+ self.log_dict = self.reduce_loss_dict(loss_dict)
187
+
188
+ if self.ema_decay > 0:
189
+ self.model_ema(decay=self.ema_decay)
190
+
191
+
192
+ def test(self):
193
+ with torch.no_grad():
194
+ if hasattr(self, 'net_g_ema'):
195
+ self.net_g_ema.eval()
196
+ self.output, _, _ = self.net_g_ema(self.gt)
197
+ else:
198
+ logger = get_root_logger()
199
+ logger.warning('Do not have self.net_g_ema, use self.net_g.')
200
+ self.net_g.eval()
201
+ self.output, _, _ = self.net_g(self.gt)
202
+ self.net_g.train()
203
+
204
+
205
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
206
+ if self.opt['rank'] == 0:
207
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
208
+
209
+
210
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
211
+ dataset_name = dataloader.dataset.opt['name']
212
+ with_metrics = self.opt['val'].get('metrics') is not None
213
+ if with_metrics:
214
+ self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
215
+ pbar = tqdm(total=len(dataloader), unit='image')
216
+
217
+ for idx, val_data in enumerate(dataloader):
218
+ img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
219
+ self.feed_data(val_data)
220
+ self.test()
221
+
222
+ visuals = self.get_current_visuals()
223
+ sr_img = tensor2img([visuals['result']])
224
+ if 'gt' in visuals:
225
+ gt_img = tensor2img([visuals['gt']])
226
+ del self.gt
227
+
228
+ # tentative for out of GPU memory
229
+ # VQGANModel.feed_data only sets self.gt, so there is no self.lq to free here
230
+ del self.output
231
+ torch.cuda.empty_cache()
232
+
233
+ if save_img:
234
+ if self.opt['is_train']:
235
+ save_img_path = osp.join(self.opt['path']['visualization'], img_name,
236
+ f'{img_name}_{current_iter}.png')
237
+ else:
238
+ if self.opt['val']['suffix']:
239
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
240
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
241
+ else:
242
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
243
+ f'{img_name}_{self.opt["name"]}.png')
244
+ imwrite(sr_img, save_img_path)
245
+
246
+ if with_metrics:
247
+ # calculate metrics
248
+ for name, opt_ in self.opt['val']['metrics'].items():
249
+ metric_data = dict(img1=sr_img, img2=gt_img)
250
+ self.metric_results[name] += calculate_metric(metric_data, opt_)
251
+ pbar.update(1)
252
+ pbar.set_description(f'Test {img_name}')
253
+ pbar.close()
254
+
255
+ if with_metrics:
256
+ for metric in self.metric_results.keys():
257
+ self.metric_results[metric] /= (idx + 1)
258
+
259
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
260
+
261
+
262
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
263
+ log_str = f'Validation {dataset_name}\n'
264
+ for metric, value in self.metric_results.items():
265
+ log_str += f'\t # {metric}: {value:.4f}\n'
266
+ logger = get_root_logger()
267
+ logger.info(log_str)
268
+ if tb_logger:
269
+ for metric, value in self.metric_results.items():
270
+ tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
271
+
272
+
273
+ def get_current_visuals(self):
274
+ out_dict = OrderedDict()
275
+ out_dict['gt'] = self.gt.detach().cpu()
276
+ out_dict['result'] = self.output.detach().cpu()
277
+ return out_dict
278
+
279
+ def save(self, epoch, current_iter):
280
+ if self.ema_decay > 0:
281
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
282
+ else:
283
+ self.save_network(self.net_g, 'net_g', current_iter)
284
+ self.save_network(self.net_d, 'net_d', current_iter)
285
+ self.save_training_state(epoch, current_iter)
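The Gumbel-quantizer branch of `optimize_parameters` anneals the sampling temperature linearly from 1.0 down to a floor of 1/16, which it reaches at iteration 150,000. A small self-contained sketch of that schedule:

```python
# The temperature rule from optimize_parameters: max(1/16, 1 - current_iter / 160000).
def gumbel_temperature(current_iter: int) -> float:
    return max(1 / 16, (-1 / 160000) * current_iter + 1)

for it in (0, 80_000, 150_000, 300_000):
    print(it, gumbel_temperature(it))
# 0 -> 1.0, 80000 -> 0.5, 150000 -> 0.0625, 300000 -> 0.0625 (clamped at the floor)
```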
blissful_tuner/codeformer/basicsr/ops/__init__.py ADDED
File without changes
blissful_tuner/codeformer/basicsr/ops/dcn/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
2
+ modulated_deform_conv)
3
+
4
+ __all__ = [
5
+ 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
6
+ 'modulated_deform_conv'
7
+ ]
blissful_tuner/codeformer/basicsr/ops/dcn/deform_conv.py ADDED
@@ -0,0 +1,377 @@
1
+ import math
2
+ import torch
3
+ from torch import nn as nn
4
+ from torch.autograd import Function
5
+ from torch.autograd.function import once_differentiable
6
+ from torch.nn import functional as F
7
+ from torch.nn.modules.utils import _pair, _single
8
+
9
+ try:
10
+ from . import deform_conv_ext
11
+ except ImportError:
12
+ import os
13
+ BASICSR_JIT = os.getenv('BASICSR_JIT')
14
+ if BASICSR_JIT == 'True':
15
+ from torch.utils.cpp_extension import load
16
+ module_path = os.path.dirname(__file__)
17
+ deform_conv_ext = load(
18
+ 'deform_conv',
19
+ sources=[
20
+ os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
21
+ os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
22
+ os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
23
+ ],
24
+ )
25
+
26
+
27
+ class DeformConvFunction(Function):
28
+
29
+ @staticmethod
30
+ def forward(ctx,
31
+ input,
32
+ offset,
33
+ weight,
34
+ stride=1,
35
+ padding=0,
36
+ dilation=1,
37
+ groups=1,
38
+ deformable_groups=1,
39
+ im2col_step=64):
40
+ if input is not None and input.dim() != 4:
41
+ raise ValueError(f'Expected 4D tensor as input, got {input.dim()}' 'D tensor instead.')
42
+ ctx.stride = _pair(stride)
43
+ ctx.padding = _pair(padding)
44
+ ctx.dilation = _pair(dilation)
45
+ ctx.groups = groups
46
+ ctx.deformable_groups = deformable_groups
47
+ ctx.im2col_step = im2col_step
48
+
49
+ ctx.save_for_backward(input, offset, weight)
50
+
51
+ output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
52
+
53
+ ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
54
+
55
+ if not input.is_cuda:
56
+ raise NotImplementedError
57
+ else:
58
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
59
+ assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
60
+ deform_conv_ext.deform_conv_forward(input, weight,
61
+ offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
62
+ weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
63
+ ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
64
+ ctx.deformable_groups, cur_im2col_step)
65
+ return output
66
+
67
+ @staticmethod
68
+ @once_differentiable
69
+ def backward(ctx, grad_output):
70
+ input, offset, weight = ctx.saved_tensors
71
+
72
+ grad_input = grad_offset = grad_weight = None
73
+
74
+ if not grad_output.is_cuda:
75
+ raise NotImplementedError
76
+ else:
77
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
78
+ assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
79
+
80
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
81
+ grad_input = torch.zeros_like(input)
82
+ grad_offset = torch.zeros_like(offset)
83
+ deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
84
+ grad_offset, weight, ctx.bufs_[0], weight.size(3),
85
+ weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
86
+ ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
87
+ ctx.deformable_groups, cur_im2col_step)
88
+
89
+ if ctx.needs_input_grad[2]:
90
+ grad_weight = torch.zeros_like(weight)
91
+ deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
92
+ ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
93
+ weight.size(2), ctx.stride[1], ctx.stride[0],
94
+ ctx.padding[1], ctx.padding[0], ctx.dilation[1],
95
+ ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
96
+ cur_im2col_step)
97
+
98
+ return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
99
+
100
+ @staticmethod
101
+ def _output_size(input, weight, padding, dilation, stride):
102
+ channels = weight.size(0)
103
+ output_size = (input.size(0), channels)
104
+ for d in range(input.dim() - 2):
105
+ in_size = input.size(d + 2)
106
+ pad = padding[d]
107
+ kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
108
+ stride_ = stride[d]
109
+ output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
110
+ if not all(map(lambda s: s > 0, output_size)):
111
+ raise ValueError('convolution input is too small (output would be ' f'{"x".join(map(str, output_size))})')
112
+ return output_size
113
+
114
+
115
+ class ModulatedDeformConvFunction(Function):
116
+
117
+ @staticmethod
118
+ def forward(ctx,
119
+ input,
120
+ offset,
121
+ mask,
122
+ weight,
123
+ bias=None,
124
+ stride=1,
125
+ padding=0,
126
+ dilation=1,
127
+ groups=1,
128
+ deformable_groups=1):
129
+ ctx.stride = stride
130
+ ctx.padding = padding
131
+ ctx.dilation = dilation
132
+ ctx.groups = groups
133
+ ctx.deformable_groups = deformable_groups
134
+ ctx.with_bias = bias is not None
135
+ if not ctx.with_bias:
136
+ bias = input.new_empty(1) # fake tensor
137
+ if not input.is_cuda:
138
+ raise NotImplementedError
139
+ if weight.requires_grad or mask.requires_grad or offset.requires_grad \
140
+ or input.requires_grad:
141
+ ctx.save_for_backward(input, offset, mask, weight, bias)
142
+ output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
143
+ ctx._bufs = [input.new_empty(0), input.new_empty(0)]
144
+ deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
145
+ ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
146
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
147
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
148
+ return output
149
+
150
+ @staticmethod
151
+ @once_differentiable
152
+ def backward(ctx, grad_output):
153
+ if not grad_output.is_cuda:
154
+ raise NotImplementedError
155
+ input, offset, mask, weight, bias = ctx.saved_tensors
156
+ grad_input = torch.zeros_like(input)
157
+ grad_offset = torch.zeros_like(offset)
158
+ grad_mask = torch.zeros_like(mask)
159
+ grad_weight = torch.zeros_like(weight)
160
+ grad_bias = torch.zeros_like(bias)
161
+ deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
162
+ grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
163
+ grad_output, weight.shape[2], weight.shape[3], ctx.stride,
164
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
165
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
166
+ if not ctx.with_bias:
167
+ grad_bias = None
168
+
169
+ return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)
170
+
171
+ @staticmethod
172
+ def _infer_shape(ctx, input, weight):
173
+ n = input.size(0)
174
+ channels_out = weight.size(0)
175
+ height, width = input.shape[2:4]
176
+ kernel_h, kernel_w = weight.shape[2:4]
177
+ height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
178
+ width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
179
+ return n, channels_out, height_out, width_out
180
+
181
+
182
+ deform_conv = DeformConvFunction.apply
183
+ modulated_deform_conv = ModulatedDeformConvFunction.apply
184
+
185
+
186
+ class DeformConv(nn.Module):
187
+
188
+ def __init__(self,
189
+ in_channels,
190
+ out_channels,
191
+ kernel_size,
192
+ stride=1,
193
+ padding=0,
194
+ dilation=1,
195
+ groups=1,
196
+ deformable_groups=1,
197
+ bias=False):
198
+ super(DeformConv, self).__init__()
199
+
200
+ assert not bias
201
+ assert in_channels % groups == 0, \
202
+ f'in_channels {in_channels} is not divisible by groups {groups}'
203
+ assert out_channels % groups == 0, \
204
+ f'out_channels {out_channels} is not divisible ' \
205
+ f'by groups {groups}'
206
+
207
+ self.in_channels = in_channels
208
+ self.out_channels = out_channels
209
+ self.kernel_size = _pair(kernel_size)
210
+ self.stride = _pair(stride)
211
+ self.padding = _pair(padding)
212
+ self.dilation = _pair(dilation)
213
+ self.groups = groups
214
+ self.deformable_groups = deformable_groups
215
+ # enable compatibility with nn.Conv2d
216
+ self.transposed = False
217
+ self.output_padding = _single(0)
218
+
219
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
220
+
221
+ self.reset_parameters()
222
+
223
+ def reset_parameters(self):
224
+ n = self.in_channels
225
+ for k in self.kernel_size:
226
+ n *= k
227
+ stdv = 1. / math.sqrt(n)
228
+ self.weight.data.uniform_(-stdv, stdv)
229
+
230
+ def forward(self, x, offset):
231
+ # To fix an assert error in deform_conv_cuda.cpp:128
232
+ # input image is smaller than kernel
233
+ input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
234
+ if input_pad:
235
+ pad_h = max(self.kernel_size[0] - x.size(2), 0)
236
+ pad_w = max(self.kernel_size[1] - x.size(3), 0)
237
+ x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
238
+ offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
239
+ out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
240
+ self.deformable_groups)
241
+ if input_pad:
242
+ out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
243
+ return out
244
+
245
+
246
+ class DeformConvPack(DeformConv):
247
+ """A Deformable Conv Encapsulation that acts as normal Conv layers.
248
+
249
+ Args:
250
+ in_channels (int): Same as nn.Conv2d.
251
+ out_channels (int): Same as nn.Conv2d.
252
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
253
+ stride (int or tuple[int]): Same as nn.Conv2d.
254
+ padding (int or tuple[int]): Same as nn.Conv2d.
255
+ dilation (int or tuple[int]): Same as nn.Conv2d.
256
+ groups (int): Same as nn.Conv2d.
257
+ bias (bool or str): If specified as `auto`, it will be decided by the
258
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
259
+ False.
260
+ """
261
+
262
+ _version = 2
263
+
264
+ def __init__(self, *args, **kwargs):
265
+ super(DeformConvPack, self).__init__(*args, **kwargs)
266
+
267
+ self.conv_offset = nn.Conv2d(
268
+ self.in_channels,
269
+ self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
270
+ kernel_size=self.kernel_size,
271
+ stride=_pair(self.stride),
272
+ padding=_pair(self.padding),
273
+ dilation=_pair(self.dilation),
274
+ bias=True)
275
+ self.init_offset()
276
+
277
+ def init_offset(self):
278
+ self.conv_offset.weight.data.zero_()
279
+ self.conv_offset.bias.data.zero_()
280
+
281
+ def forward(self, x):
282
+ offset = self.conv_offset(x)
283
+ return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
284
+ self.deformable_groups)
285
+
286
+
287
+ class ModulatedDeformConv(nn.Module):
288
+
289
+ def __init__(self,
290
+ in_channels,
291
+ out_channels,
292
+ kernel_size,
293
+ stride=1,
294
+ padding=0,
295
+ dilation=1,
296
+ groups=1,
297
+ deformable_groups=1,
298
+ bias=True):
299
+ super(ModulatedDeformConv, self).__init__()
300
+ self.in_channels = in_channels
301
+ self.out_channels = out_channels
302
+ self.kernel_size = _pair(kernel_size)
303
+ self.stride = stride
304
+ self.padding = padding
305
+ self.dilation = dilation
306
+ self.groups = groups
307
+ self.deformable_groups = deformable_groups
308
+ self.with_bias = bias
309
+ # enable compatibility with nn.Conv2d
310
+ self.transposed = False
311
+ self.output_padding = _single(0)
312
+
313
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
314
+ if bias:
315
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
316
+ else:
317
+ self.register_parameter('bias', None)
318
+ self.init_weights()
319
+
320
+ def init_weights(self):
321
+ n = self.in_channels
322
+ for k in self.kernel_size:
323
+ n *= k
324
+ stdv = 1. / math.sqrt(n)
325
+ self.weight.data.uniform_(-stdv, stdv)
326
+ if self.bias is not None:
327
+ self.bias.data.zero_()
328
+
329
+ def forward(self, x, offset, mask):
330
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
331
+ self.groups, self.deformable_groups)
332
+
333
+
334
+ class ModulatedDeformConvPack(ModulatedDeformConv):
335
+ """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
336
+
337
+ Args:
338
+ in_channels (int): Same as nn.Conv2d.
339
+ out_channels (int): Same as nn.Conv2d.
340
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
341
+ stride (int or tuple[int]): Same as nn.Conv2d.
342
+ padding (int or tuple[int]): Same as nn.Conv2d.
343
+ dilation (int or tuple[int]): Same as nn.Conv2d.
344
+ groups (int): Same as nn.Conv2d.
345
+ bias (bool or str): If specified as `auto`, it will be decided by the
346
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
347
+ False.
348
+ """
349
+
350
+ _version = 2
351
+
352
+ def __init__(self, *args, **kwargs):
353
+ super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
354
+
355
+ self.conv_offset = nn.Conv2d(
356
+ self.in_channels,
357
+ self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
358
+ kernel_size=self.kernel_size,
359
+ stride=_pair(self.stride),
360
+ padding=_pair(self.padding),
361
+ dilation=_pair(self.dilation),
362
+ bias=True)
363
+ self.init_weights()
364
+
365
+ def init_weights(self):
366
+ super(ModulatedDeformConvPack, self).init_weights()
367
+ if hasattr(self, 'conv_offset'):
368
+ self.conv_offset.weight.data.zero_()
369
+ self.conv_offset.bias.data.zero_()
370
+
371
+ def forward(self, x):
372
+ out = self.conv_offset(x)
373
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
374
+ offset = torch.cat((o1, o2), dim=1)
375
+ mask = torch.sigmoid(mask)
376
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
377
+ self.groups, self.deformable_groups)
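For orientation, a hedged usage sketch of the `*Pack` variants defined above, which predict their own offsets (and modulation mask) from the input. It assumes the `deform_conv_ext` CUDA extension has been built (for example via `BASICSR_JIT=True`) and that a GPU is available; the import path and shapes are illustrative only:

```python
# Sketch only: requires the compiled deform_conv_ext extension and a CUDA device.
import torch
from codeformer.basicsr.ops.dcn import DeformConvPack, ModulatedDeformConvPack  # assumed import path

x = torch.randn(1, 16, 32, 32, device='cuda')
dconv = DeformConvPack(16, 32, kernel_size=3, padding=1).cuda()            # offsets from conv_offset
mdconv = ModulatedDeformConvPack(16, 32, kernel_size=3, padding=1).cuda()  # offsets + sigmoid mask
print(dconv(x).shape, mdconv(x).shape)  # torch.Size([1, 32, 32, 32]) for both
```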
blissful_tuner/codeformer/basicsr/ops/dcn/src/deform_conv_cuda.cpp ADDED
@@ -0,0 +1,685 @@
1
+ // modify from
2
+ // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
3
+
4
+ #include <torch/extension.h>
5
+ #include <ATen/DeviceGuard.h>
6
+
7
+ #include <cmath>
8
+ #include <vector>
9
+
10
+ void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
11
+ const int channels, const int height, const int width,
12
+ const int ksize_h, const int ksize_w, const int pad_h,
13
+ const int pad_w, const int stride_h, const int stride_w,
14
+ const int dilation_h, const int dilation_w,
15
+ const int parallel_imgs, const int deformable_group,
16
+ at::Tensor data_col);
17
+
18
+ void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
19
+ const int channels, const int height, const int width,
20
+ const int ksize_h, const int ksize_w, const int pad_h,
21
+ const int pad_w, const int stride_h, const int stride_w,
22
+ const int dilation_h, const int dilation_w,
23
+ const int parallel_imgs, const int deformable_group,
24
+ at::Tensor grad_im);
25
+
26
+ void deformable_col2im_coord(
27
+ const at::Tensor data_col, const at::Tensor data_im,
28
+ const at::Tensor data_offset, const int channels, const int height,
29
+ const int width, const int ksize_h, const int ksize_w, const int pad_h,
30
+ const int pad_w, const int stride_h, const int stride_w,
31
+ const int dilation_h, const int dilation_w, const int parallel_imgs,
32
+ const int deformable_group, at::Tensor grad_offset);
33
+
34
+ void modulated_deformable_im2col_cuda(
35
+ const at::Tensor data_im, const at::Tensor data_offset,
36
+ const at::Tensor data_mask, const int batch_size, const int channels,
37
+ const int height_im, const int width_im, const int height_col,
38
+ const int width_col, const int kernel_h, const int kenerl_w,
39
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
40
+ const int dilation_h, const int dilation_w, const int deformable_group,
41
+ at::Tensor data_col);
42
+
43
+ void modulated_deformable_col2im_cuda(
44
+ const at::Tensor data_col, const at::Tensor data_offset,
45
+ const at::Tensor data_mask, const int batch_size, const int channels,
46
+ const int height_im, const int width_im, const int height_col,
47
+ const int width_col, const int kernel_h, const int kenerl_w,
48
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
49
+ const int dilation_h, const int dilation_w, const int deformable_group,
50
+ at::Tensor grad_im);
51
+
52
+ void modulated_deformable_col2im_coord_cuda(
53
+ const at::Tensor data_col, const at::Tensor data_im,
54
+ const at::Tensor data_offset, const at::Tensor data_mask,
55
+ const int batch_size, const int channels, const int height_im,
56
+ const int width_im, const int height_col, const int width_col,
57
+ const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
58
+ const int stride_h, const int stride_w, const int dilation_h,
59
+ const int dilation_w, const int deformable_group, at::Tensor grad_offset,
60
+ at::Tensor grad_mask);
61
+
+ void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
+ at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
+ int padW, int dilationH, int dilationW, int group,
+ int deformable_group) {
+ TORCH_CHECK(weight.ndimension() == 4,
+ "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
+ "but got: %s",
+ weight.ndimension());
+
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+ TORCH_CHECK(kW > 0 && kH > 0,
+ "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
+ kW);
+
+ TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
+ "kernel size should be consistent with weight, ",
+ "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
+ kW, weight.size(2), weight.size(3));
+
+ TORCH_CHECK(dW > 0 && dH > 0,
+ "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
+
+ TORCH_CHECK(
+ dilationW > 0 && dilationH > 0,
+ "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
+ dilationH, dilationW);
+
+ int ndim = input.ndimension();
+ int dimf = 0;
+ int dimh = 1;
+ int dimw = 2;
+
+ if (ndim == 4) {
+ dimf++;
+ dimh++;
+ dimw++;
+ }
+
+ TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
+ ndim);
+
+ long nInputPlane = weight.size(1) * group;
+ long inputHeight = input.size(dimh);
+ long inputWidth = input.size(dimw);
+ long nOutputPlane = weight.size(0);
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+
+ TORCH_CHECK(nInputPlane % deformable_group == 0,
+ "input channels must divide deformable group size");
+
+ if (outputWidth < 1 || outputHeight < 1)
+ AT_ERROR(
+ "Given input size: (%ld x %ld x %ld). "
+ "Calculated output size: (%ld x %ld x %ld). Output size is too small",
+ nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
+ outputWidth);
+
+ TORCH_CHECK(input.size(1) == nInputPlane,
+ "invalid number of input planes, expected: %d, but got: %d",
+ nInputPlane, input.size(1));
+
+ TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
+ "input image is smaller than kernel");
+
+ TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
+ "invalid spatial size of offset, expected height: %d width: %d, but "
+ "got height: %d width: %d",
+ outputHeight, outputWidth, offset.size(2), offset.size(3));
+
+ TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
+ "invalid number of channels of offset");
+
+ if (gradOutput != NULL) {
+ TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
+ "invalid number of gradOutput planes, expected: %d, but got: %d",
+ nOutputPlane, gradOutput->size(dimf));
+
+ TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
+ gradOutput->size(dimw) == outputWidth),
+ "invalid size of gradOutput, expected height: %d width: %d , but "
+ "got height: %d width: %d",
+ outputHeight, outputWidth, gradOutput->size(dimh),
+ gradOutput->size(dimw));
+ }
+ }
+
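+ // Worked example of the output-size formula used in shape_check: with
+ // inputHeight = 64, kH = 3, padH = 1, dH = 1 and dilationH = 1,
+ // outputHeight = (64 + 2*1 - (1*(3-1) + 1)) / 1 + 1 = 64 ("same" size);
+ // the same input with stride dH = 2 gives (64 + 2 - 3) / 2 + 1 = 32
+ // (integer division).
+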
+ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+ at::Tensor offset, at::Tensor output,
+ at::Tensor columns, at::Tensor ones, int kW,
+ int kH, int dW, int dH, int padW, int padH,
+ int dilationW, int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ // todo: resize columns to include im2col: done
+ // todo: add im2col_step as input
+ // todo: add new output buffer and transpose it to output (or directly
+ // transpose output) todo: possibly change data indexing because of
+ // parallel_imgs
+
+ shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ weight = weight.contiguous();
+
+ int batch = 1;
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input.unsqueeze_(0);
+ offset.unsqueeze_(0);
+ }
+
+ // todo: assert batchsize dividable by im2col_step
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = weight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+ output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
+ outputHeight, outputWidth});
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
+ ones = at::ones({outputHeight, outputWidth}, input.options());
+ }
+
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ at::Tensor output_buffer =
+ at::zeros({batchSize / im2col_step, nOutputPlane,
+ im2col_step * outputHeight, outputWidth},
+ output.options());
+
+ output_buffer = output_buffer.view(
+ {output_buffer.size(0), group, output_buffer.size(1) / group,
+ output_buffer.size(2), output_buffer.size(3)});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, columns);
+
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ output_buffer[elt][g] = output_buffer[elt][g]
+ .flatten(1)
+ .addmm_(weight[g].flatten(1), columns[g])
+ .view_as(output_buffer[elt][g]);
+ }
+ }
+
+ output_buffer = output_buffer.view(
+ {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
+ output_buffer.size(3), output_buffer.size(4)});
+
+ output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
+ im2col_step, outputHeight, outputWidth});
+ output_buffer.transpose_(1, 2);
+ output.copy_(output_buffer);
+ output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ output = output.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+ }
+
+ return 1;
+ }
+
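+ // Shape sketch for deform_conv_forward_cuda, assuming the usual Python-side
+ // wrapper: input [N, C_in, H, W], offset [N, 2*deformable_group*kH*kW, H_out, W_out],
+ // weight [C_out, C_in/group, kH, kW]; output is filled in place as
+ // [N, C_out, H_out, W_out], and N is expected to be divisible by im2col_step.
+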
+ int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+ at::Tensor gradOutput, at::Tensor gradInput,
+ at::Tensor gradOffset, at::Tensor weight,
+ at::Tensor columns, int kW, int kH, int dW,
+ int dH, int padW, int padH, int dilationW,
+ int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ gradOutput = gradOutput.contiguous();
+ weight = weight.contiguous();
+
+ int batch = 1;
+
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input = input.view({1, input.size(0), input.size(1), input.size(2)});
+ offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
+ gradOutput = gradOutput.view(
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+ }
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = weight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ // change order of grad output
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+ nOutputPlane, outputHeight, outputWidth});
+ gradOutput.transpose_(1, 2);
+
+ gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight,
+ outputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ // divide into groups
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+ gradOutput = gradOutput.view(
+ {gradOutput.size(0), group, gradOutput.size(1) / group,
+ gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
+
+ for (int g = 0; g < group; g++) {
+ columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+ gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ gradOutput = gradOutput.view(
+ {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
+ gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
+
+ deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
+ inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
+ dilationH, dilationW, im2col_step, deformable_group,
+ gradOffset[elt]);
+
+ deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, gradInput[elt]);
+ }
+
+ gradOutput.transpose_(1, 2);
+ gradOutput =
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ gradOffset = gradOffset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+ gradOffset =
+ gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
+ }
+
+ return 1;
+ }
+
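+ // In the backward-input path above, addmm_ is called with beta = 0.0f and
+ // alpha = 1.0f, so each chunk overwrites `columns` with weight^T @ gradOutput;
+ // deformable_col2im_coord and deformable_col2im then scatter those columns
+ // back into gradOffset and gradInput.
+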
+ int deform_conv_backward_parameters_cuda(
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+ at::Tensor gradWeight, // at::Tensor gradBias,
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+ int padW, int padH, int dilationW, int dilationH, int group,
+ int deformable_group, float scale, int im2col_step) {
+ // todo: transpose and reshape outGrad
+ // todo: reshape columns
+ // todo: add im2col_step as input
+
+ shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
+ padW, dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ gradOutput = gradOutput.contiguous();
+
+ int batch = 1;
+
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input = input.view(
+ at::IntList({1, input.size(0), input.size(1), input.size(2)}));
+ gradOutput = gradOutput.view(
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+ }
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = gradWeight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+ nOutputPlane, outputHeight, outputWidth});
+ gradOutput.transpose_(1, 2);
+
+ at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
+ gradOutputBuffer =
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
+ outputHeight, outputWidth});
+ gradOutputBuffer.copy_(gradOutput);
+ gradOutputBuffer =
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
+ im2col_step * outputHeight, outputWidth});
+
+ gradOutput.transpose_(1, 2);
+ gradOutput =
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, columns);
+
+ // divide into group
+ gradOutputBuffer = gradOutputBuffer.view(
+ {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
+ gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ gradWeight =
+ gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
+ gradWeight.size(2), gradWeight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ gradWeight[g] = gradWeight[g]
+ .flatten(1)
+ .addmm_(gradOutputBuffer[elt][g].flatten(1),
+ columns[g].transpose(1, 0), 1.0, scale)
+ .view_as(gradWeight[g]);
+ }
+ gradOutputBuffer = gradOutputBuffer.view(
+ {gradOutputBuffer.size(0),
+ gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
+ gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
+ gradWeight.size(2), gradWeight.size(3),
+ gradWeight.size(4)});
+ }
+
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ }
+
+ return 1;
+ }
+
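+ // In the weight-gradient path above, addmm_ uses beta = 1.0 and alpha = scale,
+ // so gradWeight accumulates scale * gradOutput @ columns^T across the
+ // batchSize / im2col_step chunks instead of being overwritten each iteration.
+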
+ void modulated_deform_conv_cuda_forward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+ const int pad_h, const int pad_w, const int dilation_h,
+ const int dilation_w, const int group, const int deformable_group,
+ const bool with_bias) {
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+ at::DeviceGuard guard(input.device());
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_out = weight.size(0);
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+ AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
+ kernel_h, kernel_w, kernel_h_, kernel_w_);
+ if (channels != channels_kernel * group)
+ AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
+ channels, channels_kernel * group);
+
+ const int height_out =
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out =
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < height_out * width_out) {
+ // Resize plane and fill with ones...
+ ones = at::ones({height_out, width_out}, input.options());
+ }
+
+ // resize output
+ output = output.view({batch, channels_out, height_out, width_out}).zero_();
+ // resize temporary columns
+ columns =
+ at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
+ input.options());
+
+ output = output.view({output.size(0), group, output.size(1) / group,
+ output.size(2), output.size(3)});
+
+ for (int b = 0; b < batch; b++) {
+ modulated_deformable_im2col_cuda(
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, columns);
+
+ // divide into group
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+
+ for (int g = 0; g < group; g++) {
+ output[b][g] = output[b][g]
+ .flatten(1)
+ .addmm_(weight[g].flatten(1), columns[g])
+ .view_as(output[b][g]);
+ }
+
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+ weight.size(3), weight.size(4)});
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ }
+
+ output = output.view({output.size(0), output.size(1) * output.size(2),
+ output.size(3), output.size(4)});
+
+ if (with_bias) {
+ output += bias.view({1, bias.size(0), 1, 1});
+ }
+ }
+
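+ // Shape sketch for the modulated (DCNv2) forward above, assuming the usual
+ // Python-side wrapper: offset carries 2*deformable_group*kernel_h*kernel_w
+ // channels and mask carries deformable_group*kernel_h*kernel_w channels, both
+ // at the output resolution [batch, *, height_out, width_out]; bias is added
+ // only when with_bias is true.
+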
+ void modulated_deform_conv_cuda_backward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+ const bool with_bias) {
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+ at::DeviceGuard guard(input.device());
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+ AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
+ kernel_h, kernel_w, kernel_h_, kernel_w_);
+ if (channels != channels_kernel * group)
+ AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
+ channels, channels_kernel * group);
+
+ const int height_out =
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out =
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < height_out * width_out) {
+ // Resize plane and fill with ones...
+ ones = at::ones({height_out, width_out}, input.options());
+ }
+
+ grad_input = grad_input.view({batch, channels, height, width});
+ columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
+ input.options());
+
+ grad_output =
+ grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
+ grad_output.size(2), grad_output.size(3)});
+
+ for (int b = 0; b < batch; b++) {
+ // divide into group
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+ grad_output[b][g].flatten(1), 0.0f, 1.0f);
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+ weight.size(3), weight.size(4)});
+
+ // gradient w.r.t. input coordinate data
+ modulated_deformable_col2im_coord_cuda(
+ columns, input[b], offset[b], mask[b], 1, channels, height, width,
+ height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
+ stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
+ grad_mask[b]);
+ // gradient w.r.t. input data
+ modulated_deformable_col2im_cuda(
+ columns, offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, grad_input[b]);
+
+ // gradient w.r.t. weight, dWeight should accumulate across the batch and
+ // group
+ modulated_deformable_im2col_cuda(
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, columns);
+
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
+ grad_weight.size(1), grad_weight.size(2),
+ grad_weight.size(3)});
+ if (with_bias)
+ grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
+
+ for (int g = 0; g < group; g++) {
+ grad_weight[g] =
+ grad_weight[g]
+ .flatten(1)
+ .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
+ .view_as(grad_weight[g]);
+ if (with_bias) {
+ grad_bias[g] =
+ grad_bias[g]
+ .view({-1, 1})
+ .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
+ .view(-1);
+ }
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
+ grad_weight.size(2), grad_weight.size(3),
+ grad_weight.size(4)});
+ if (with_bias)
+ grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
+ }
+ grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
+ grad_output.size(2), grad_output.size(3),
+ grad_output.size(4)});
+ }