Spaces:
Running
Running
Upload 303 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .gitignore +18 -0
- .python-version +1 -0
- README.ja.md +377 -0
- README.md +269 -12
- base_hv_generate_video.py +936 -0
- base_wan_generate_video.py +1892 -0
- blissful_tuner/GIMMVFI.py +208 -0
- blissful_tuner/__init__.py +0 -0
- blissful_tuner/advanced_rope.py +112 -0
- blissful_tuner/blissful_args.py +131 -0
- blissful_tuner/blissful_settings.py +111 -0
- blissful_tuner/cfgzerostar.py +39 -0
- blissful_tuner/codeformer/LICENSE +15 -0
- blissful_tuner/codeformer/basicsr/VERSION +1 -0
- blissful_tuner/codeformer/basicsr/__init__.py +11 -0
- blissful_tuner/codeformer/basicsr/archs/__init__.py +25 -0
- blissful_tuner/codeformer/basicsr/archs/arcface_arch.py +245 -0
- blissful_tuner/codeformer/basicsr/archs/arch_util.py +318 -0
- blissful_tuner/codeformer/basicsr/archs/codeformer_arch.py +280 -0
- blissful_tuner/codeformer/basicsr/archs/rrdbnet_arch.py +119 -0
- blissful_tuner/codeformer/basicsr/archs/vgg_arch.py +161 -0
- blissful_tuner/codeformer/basicsr/archs/vqgan_arch.py +434 -0
- blissful_tuner/codeformer/basicsr/data/__init__.py +100 -0
- blissful_tuner/codeformer/basicsr/data/data_sampler.py +48 -0
- blissful_tuner/codeformer/basicsr/data/data_util.py +392 -0
- blissful_tuner/codeformer/basicsr/data/ffhq_blind_dataset.py +299 -0
- blissful_tuner/codeformer/basicsr/data/ffhq_blind_joint_dataset.py +324 -0
- blissful_tuner/codeformer/basicsr/data/gaussian_kernels.py +690 -0
- blissful_tuner/codeformer/basicsr/data/paired_image_dataset.py +101 -0
- blissful_tuner/codeformer/basicsr/data/prefetch_dataloader.py +125 -0
- blissful_tuner/codeformer/basicsr/data/transforms.py +165 -0
- blissful_tuner/codeformer/basicsr/losses/__init__.py +26 -0
- blissful_tuner/codeformer/basicsr/losses/loss_util.py +95 -0
- blissful_tuner/codeformer/basicsr/losses/losses.py +455 -0
- blissful_tuner/codeformer/basicsr/metrics/__init__.py +19 -0
- blissful_tuner/codeformer/basicsr/metrics/metric_util.py +45 -0
- blissful_tuner/codeformer/basicsr/metrics/psnr_ssim.py +128 -0
- blissful_tuner/codeformer/basicsr/models/__init__.py +30 -0
- blissful_tuner/codeformer/basicsr/models/base_model.py +322 -0
- blissful_tuner/codeformer/basicsr/models/codeformer_idx_model.py +220 -0
- blissful_tuner/codeformer/basicsr/models/codeformer_joint_model.py +350 -0
- blissful_tuner/codeformer/basicsr/models/codeformer_model.py +332 -0
- blissful_tuner/codeformer/basicsr/models/lr_scheduler.py +96 -0
- blissful_tuner/codeformer/basicsr/models/sr_model.py +209 -0
- blissful_tuner/codeformer/basicsr/models/vqgan_model.py +285 -0
- blissful_tuner/codeformer/basicsr/ops/__init__.py +0 -0
- blissful_tuner/codeformer/basicsr/ops/dcn/__init__.py +7 -0
- blissful_tuner/codeformer/basicsr/ops/dcn/deform_conv.py +377 -0
- blissful_tuner/codeformer/basicsr/ops/dcn/src/deform_conv_cuda.cpp +685 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
images/screenshot.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
.venv
|
3 |
+
venv/
|
4 |
+
logs/
|
5 |
+
uv.lock
|
6 |
+
.env
|
7 |
+
env/
|
8 |
+
outputs/
|
9 |
+
GIMM-VFI/
|
10 |
+
hunyuan/
|
11 |
+
temp_frames/
|
12 |
+
wan/wan2.1_i2v_480p_14B_bf16.safetensors
|
13 |
+
wan/wan2.1_t2v_14B_bf16.safetensors
|
14 |
+
wan/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
|
15 |
+
wan/models_t5_umt5-xxl-enc-bf16.pth
|
16 |
+
triton-3.0.0-cp310-cp310-win_amd64.whl
|
17 |
+
wan/Wan2.1_VAE.pth
|
18 |
+
flash_attn-2.7.4+cu124torch2.5.1cxx11abiFALSE-cp310-cp310-win_amd64.whl
|
.python-version
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
3.10
|
README.ja.md
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Musubi Tuner
|
2 |
+
|
3 |
+
[English](./README.md) | [日本語](./README.ja.md)
|
4 |
+
|
5 |
+
## 目次
|
6 |
+
|
7 |
+
- [はじめに](#はじめに)
|
8 |
+
- [最近の更新](#最近の更新)
|
9 |
+
- [リリースについて](#リリースについて)
|
10 |
+
- [概要](#概要)
|
11 |
+
- [ハードウェア要件](#ハードウェア要件)
|
12 |
+
- [特徴](#特徴)
|
13 |
+
- [インストール](#インストール)
|
14 |
+
- [モデルのダウンロード](#モデルのダウンロード)
|
15 |
+
- [HunyuanVideoの公式モデルを使う](#HunyuanVideoの公式モデルを使う)
|
16 |
+
- [Text EncoderにComfyUI提供のモデルを使う](#Text-EncoderにComfyUI提供のモデルを使う)
|
17 |
+
- [使い方](#使い方)
|
18 |
+
- [データセット設定](#データセット設定)
|
19 |
+
- [latentの事前キャッシュ](#latentの事前キャッシュ)
|
20 |
+
- [Text Encoder出力の事前キャッシュ](#Text-Encoder出力の事前キャッシュ)
|
21 |
+
- [学習](#学習)
|
22 |
+
- [LoRAの重みのマージ](#LoRAの重みのマージ)
|
23 |
+
- [推論](#推論)
|
24 |
+
- [LoRAの形式の変換](#LoRAの形式の変換)
|
25 |
+
- [その他](#その他)
|
26 |
+
- [SageAttentionのインストール方法](#SageAttentionのインストール方法)
|
27 |
+
- [免責事項](#免責事項)
|
28 |
+
- [コントリビューションについて](#コントリビューションについて)
|
29 |
+
- [ライセンス](#ライセンス)
|
30 |
+
|
31 |
+
## はじめに
|
32 |
+
|
33 |
+
このリポジトリは、HunyuanVideoのLoRA学習用のコマンドラインツールです。このリポジトリは非公式であり、公式のHunyuanVideoリポジトリとは関係ありません。
|
34 |
+
|
35 |
+
*リポジトリは開発中です。*
|
36 |
+
|
37 |
+
### 最近の更新
|
38 |
+
|
39 |
+
- 2025/01/20
|
40 |
+
- uv によるインストール手順を試験的に追加しました。PR [#51](https://github.com/kohya-ss/musubi-tuner/pull/51) bmaltais 氏に感謝いたします。ただ、設定等は詰められていないため、フィードバックを歓迎します。
|
41 |
+
- 高度な設定に、[TensorBoard形式のログの保存と参照](./docs/advanced_config.md#save-and-view-logs-in-tensorboard-format--tensorboard形式のログの保存と参照)を追加しました。
|
42 |
+
|
43 |
+
- 2025/01/19
|
44 |
+
- latentとText Encoder出力の事前キャッシュ時に、データセットに含まれないキャッシュファイルを自動で消去するようにしました。これにより予期しないファイルが残り、学習に使用されてしまう問題が解消されます。
|
45 |
+
- `--keep_cache`で今までと同様にキャッシュファイルを残すことができます。
|
46 |
+
- Text Encoder出力の事前キャッシュ時に、`--skip_existing`を指定すると正しく動作しない問題を修正しました。
|
47 |
+
|
48 |
+
- 2025/01/18
|
49 |
+
- `hv_generate_video.py`でvideo2videoの推論が可能になりました。詳細は[推論](#推論)を参照してください。
|
50 |
+
|
51 |
+
- 2025/01/16
|
52 |
+
- LoRAの重みをマージするスクリプト、`merge_lora.py`が追加されました。PR [#37](https://github.com/kohya-ss/musubi-tuner/pull/37) kaykyr氏に感謝いたします。詳細は[LoRAの重みのマージ](#LoRAの重みのマージ)を参照してください。
|
53 |
+
- サンプルの学習設定を、学習率を2e-4に、`--timestep_sampling`を`shift`に、`--discrete_flow_shift`を7.0に変更しました。より高速な学習が期待されます。詳細は[学習](#学習)を参照してください。
|
54 |
+
|
55 |
+
- 2025/01/14
|
56 |
+
- `hv_generate_video.py`に、LoRAマージ後のDiTモデルを保存する`--save_merged_model`オプションを暫定的に追加しました。詳細は[推論](#推論)を参照してください。
|
57 |
+
|
58 |
+
- 2025/01/13
|
59 |
+
- 学習中のサンプル画像(動画)がぼやける現象に対応するため、サンプル生成時の設定を変更しました。詳細は[こちら](./docs/sampling_during_training.md)をご参照ください。
|
60 |
+
- 推論時にdiscrete flow shiftとguidance scaleを正しく設定する必要がありますが、学習時の設定がそのまま使われていたため、この事象が発生していました。デフォルト値を設定したため、改善されると思われます。また`--fs`でdiscrete flow shiftを、`--g`でguidance scaleを指定できます。
|
61 |
+
|
62 |
+
### リリースについて
|
63 |
+
|
64 |
+
Musubi Tunerの解説記事執筆や、関連ツールの開発に取り組んでくださる方々に感謝いたします。このプロジェクトは開発中のため、互換性のない変更や機能追加が起きる可能性があります。想定外の互換性問題を避けるため、参照用として[リリース](https://github.com/kohya-ss/musubi-tuner/releases)をお使いください。
|
65 |
+
|
66 |
+
最新のリリースとバージョン履歴は[リリースページ](https://github.com/kohya-ss/musubi-tuner/releases)で確認できます。
|
67 |
+
|
68 |
+
## 概要
|
69 |
+
|
70 |
+
### ハードウェア要件
|
71 |
+
|
72 |
+
- VRAM: 静止画での学習は12GB以上推奨、動画での学習は24GB以上推奨。
|
73 |
+
- *解像度等の学習設定により異なります。*12GB��は解像度 960x544 以下とし、`--blocks_to_swap`、`--fp8_llm`等の省メモリオプションを使用してください。
|
74 |
+
- メインメモリ: 64GB以上を推奨、32GB+スワップで動作するかもしれませんが、未検証です。
|
75 |
+
|
76 |
+
### 特徴
|
77 |
+
|
78 |
+
- 省メモリに特化
|
79 |
+
- Windows対応(Linuxでの動作報告もあります)
|
80 |
+
- マルチGPUには対応していません
|
81 |
+
|
82 |
+
## インストール
|
83 |
+
|
84 |
+
### pipによるインストール
|
85 |
+
|
86 |
+
Python 3.10以上を使用してください(3.10で動作確認済み)。
|
87 |
+
|
88 |
+
適当な仮想環境を作成し、ご利用のCUDAバージョンに合わせたPyTorchとtorchvisionをインストールしてください。
|
89 |
+
|
90 |
+
PyTorchはバージョン2.5.1以上を使用してください([補足](#PyTorchのバージョンについて))。
|
91 |
+
|
92 |
+
```bash
|
93 |
+
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
|
94 |
+
```
|
95 |
+
|
96 |
+
以下のコマンドを使用して、必要な依存関係をインストールします。
|
97 |
+
|
98 |
+
```bash
|
99 |
+
pip install -r requirements.txt
|
100 |
+
```
|
101 |
+
|
102 |
+
オプションとして、FlashAttention、SageAttention(推論にのみ使用、インストール方法は[こちら](#SageAttentionのインストール方法)を参照)を使用できます。
|
103 |
+
|
104 |
+
また、`ascii-magic`(データセットの確認に使用)、`matplotlib`(timestepsの可視化に使用)、`tensorboard`(学習ログの記録に使用)を必要に応じてインストールしてください。
|
105 |
+
|
106 |
+
```bash
|
107 |
+
pip install ascii-magic matplotlib tensorboard
|
108 |
+
```
|
109 |
+
### uvによるインストール
|
110 |
+
|
111 |
+
uvを使用してインストールすることもできますが、uvによるインストールは試験的なものです。フィードバックを歓迎します。
|
112 |
+
|
113 |
+
#### Linux/MacOS
|
114 |
+
|
115 |
+
```sh
|
116 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
117 |
+
```
|
118 |
+
|
119 |
+
表示される指示に従い、pathを設定してください。
|
120 |
+
|
121 |
+
#### Windows
|
122 |
+
|
123 |
+
```powershell
|
124 |
+
powershell -c "irm https://astral.sh/uv/install.ps1 | iex"
|
125 |
+
```
|
126 |
+
|
127 |
+
表示される指示に従い、PATHを設定するか、この時点でシステムを再起動してください。
|
128 |
+
|
129 |
+
## モデルのダウンロード
|
130 |
+
|
131 |
+
以下のいずれかの方法で、モデルをダウンロードしてください。
|
132 |
+
|
133 |
+
### HunyuanVideoの公式モデルを使う
|
134 |
+
|
135 |
+
[公式のREADME](https://github.com/Tencent/HunyuanVideo/blob/main/ckpts/README.md)を参考にダウンロードし、任意のディレクトリに以下のように配置します。
|
136 |
+
|
137 |
+
```
|
138 |
+
ckpts
|
139 |
+
├──hunyuan-video-t2v-720p
|
140 |
+
│ ├──transformers
|
141 |
+
│ ├──vae
|
142 |
+
├──text_encoder
|
143 |
+
├──text_encoder_2
|
144 |
+
├──...
|
145 |
+
```
|
146 |
+
|
147 |
+
### Text EncoderにComfyUI提供のモデルを使う
|
148 |
+
|
149 |
+
こちらの方法の方がより簡単です。DiTとVAEのモデルはHumyuanVideoのものを使用します。
|
150 |
+
|
151 |
+
https://huggingface.co/tencent/HunyuanVideo/tree/main/hunyuan-video-t2v-720p/transformers から、[mp_rank_00_model_states.pt](https://huggingface.co/tencent/HunyuanVideo/resolve/main/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt) をダウンロードし、任意のディレクトリに配置します。
|
152 |
+
|
153 |
+
(同じページにfp8のモデルもありますが、未検証です。)
|
154 |
+
|
155 |
+
`--fp8_base`を指定して学習する場合は、`mp_rank_00_model_states.pt`の代わりに、[こちら](https://huggingface.co/kohya-ss/HunyuanVideo-fp8_e4m3fn-unofficial)の`mp_rank_00_model_states_fp8.safetensors`を使用可能です。(このファイルは非公式のもので、重みを単純にfloat8_e4m3fnに変換したものです。)
|
156 |
+
|
157 |
+
また、https://huggingface.co/tencent/HunyuanVideo/tree/main/hunyuan-video-t2v-720p/vae から、[pytorch_model.pt](https://huggingface.co/tencent/HunyuanVideo/resolve/main/hunyuan-video-t2v-720p/vae/pytorch_model.pt) をダウンロードし、任意のディレクトリに配置します。
|
158 |
+
|
159 |
+
Text EncoderにはComfyUI提供のモデルを使用させていただきます。[ComyUIのページ](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)を参考に、https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged/tree/main/split_files/text_encoders から、llava_llama3_fp16.safetensors (Text Encoder 1、LLM)と、clip_l.safetensors (Text Encoder 2、CLIP)をダウンロードし、任意のディレクトリに配置します。
|
160 |
+
|
161 |
+
(同じページにfp8のLLMモデルもありますが、動作未検証です。)
|
162 |
+
|
163 |
+
## 使い方
|
164 |
+
|
165 |
+
### データセット設定
|
166 |
+
|
167 |
+
[こちら](./dataset/dataset_config.md)を参照してください。
|
168 |
+
|
169 |
+
### latentの事前キャッシュ
|
170 |
+
|
171 |
+
latentの事前キャッシュは必須です。以下のコマンドを使用して、事前キャッシュを作成してください。(pipによるインストールの場合)
|
172 |
+
|
173 |
+
```bash
|
174 |
+
python cache_latents.py --dataset_config path/to/toml --vae path/to/ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt --vae_chunk_size 32 --vae_tiling
|
175 |
+
```
|
176 |
+
|
177 |
+
uvでインストールした場合は、`uv run python cache_latents.py ...`のように、`uv run`を先頭につけてください。以下のコマンドも同様です。
|
178 |
+
|
179 |
+
その他のオプションは`python cache_latents.py --help`で確認できます。
|
180 |
+
|
181 |
+
VRAMが足りない場合は、`--vae_spatial_tile_sample_min_size`を128程度に減らし、`--batch_size`を小さくしてください。
|
182 |
+
|
183 |
+
`--debug_mode image` を指定するとデータセットの画像とキャプションが新規ウィンドウに表示されます。`--debug_mode console`でコンソールに表示されます(`ascii-magic`が必要)。
|
184 |
+
|
185 |
+
デフォルトではデータセットに含まれないキャッシュファイルは自動的に削除されます。`--keep_cache`を指定すると、キャッシュファイルを残すことができます。
|
186 |
+
|
187 |
+
### Text Encoder出力の事前キャッシュ
|
188 |
+
|
189 |
+
Text Encoder出力の事前キャッシュは必須です。以下のコマンドを使用して、事前キャッシュを作成してください。
|
190 |
+
|
191 |
+
```bash
|
192 |
+
python cache_text_encoder_outputs.py --dataset_config path/to/toml --text_encoder1 path/to/ckpts/text_encoder --text_encoder2 path/to/ckpts/text_encoder_2 --batch_size 16
|
193 |
+
```
|
194 |
+
|
195 |
+
その他のオプションは`python cache_text_encoder_outputs.py --help`で確認できます。
|
196 |
+
|
197 |
+
`--batch_size`はVRAMに合わせて調整してください。
|
198 |
+
|
199 |
+
VRAMが足りない場合(16GB程度未満の場合)は、`--fp8_llm`を指定して、fp8でLLMを実行してください。
|
200 |
+
|
201 |
+
デフォルトではデータセットに含まれないキャッシュファイルは自動的に削除されます。`--keep_cache`を指定すると、キャッシュファイルを残すことができます。
|
202 |
+
|
203 |
+
### 学習
|
204 |
+
|
205 |
+
以下のコマンドを使用して、学習を開始します(実際には一行で入力してください)。
|
206 |
+
|
207 |
+
```bash
|
208 |
+
accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 hv_train_network.py
|
209 |
+
--dit path/to/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt
|
210 |
+
--dataset_config path/to/toml --sdpa --mixed_precision bf16 --fp8_base
|
211 |
+
--optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing
|
212 |
+
--max_data_loader_n_workers 2 --persistent_data_loader_workers
|
213 |
+
--network_module networks.lora --network_dim 32
|
214 |
+
--timestep_sampling shift --discrete_flow_shift 7.0
|
215 |
+
--max_train_epochs 16 --save_every_n_epochs 1 --seed 42
|
216 |
+
--output_dir path/to/output_dir --output_name name-of-lora
|
217 |
+
```
|
218 |
+
|
219 |
+
__更新__:サンプルの学習率を1e-3から2e-4に、`--timestep_sampling`を`sigmoid`から`shift`に、`--discrete_flow_shift`を1.0から7.0に変更しました。より高速な学習が期待されます。ディテールが甘くなる場合は、discrete flow shiftを3.0程度に下げてみてください。
|
220 |
+
|
221 |
+
ただ、適切な学習率、学習ステップ数、timestepsの分布、loss weightingなどのパラメータは、以前として不明な点が数多くあります。情報提供をお待ちしています。
|
222 |
+
|
223 |
+
その他のオプションは`python hv_train_network.py --help`で確認できます(ただし多くのオプションは動作未確認です)。
|
224 |
+
|
225 |
+
`--fp8_base`を指定すると、DiTがfp8で学習されます。未指定時はmixed precisionのデータ型が使用されます。fp8は大きく消費メモリを削減できますが、品質は低下する可能性があります。`--fp8_base`を指定しない場合はVRAM 24GB以上を推奨します。また必要に応じて`--blocks_to_swap`を使用してください。
|
226 |
+
|
227 |
+
VRAMが足りない場合は、`--blocks_to_swap`を指定して、一部のブロックをCPUにオフロードしてください。最大36が指定できます。
|
228 |
+
|
229 |
+
(block swapのアイデアは2kpr氏の実装に基づくものです。2kpr氏にあらためて感謝します。)
|
230 |
+
|
231 |
+
`--sdpa`でPyTorchのscaled dot product attentionを使用します。`--flash_attn`で[FlashAttention]:(https://github.com/Dao-AILab/flash-attention)を使用します。`--xformers`でxformersの利用も可能ですが、xformersを使う場合は`--split_attn`を指定してください。`--sage_attn`でSageAttentionを使用しますが、SageAttentionは現時点では学習に未対応のため、正しく動作しません。
|
232 |
+
|
233 |
+
`--split_attn`を指定すると、attentionを分割して処理します。速度が多少低下しますが、VRAM使用量はわずかに減ります。
|
234 |
+
|
235 |
+
学習されるLoRAの形式は、`sd-scripts`と同じです。
|
236 |
+
|
237 |
+
`--show_timesteps`に`image`(`matplotlib`が必要)または`console`を指定すると、学習時のtimestepsの分布とtimestepsごとのloss weightingが確認できます。
|
238 |
+
|
239 |
+
学習時のログの記録が可能です。[TensorBoard形式のログの保存と参照](./docs/advanced_config.md#save-and-view-logs-in-tensorboard-format--tensorboard形式のログの保存と参照)を参照してください。
|
240 |
+
|
241 |
+
学習中のサンプル画像生成については、[こちらのドキュメント](./docs/sampling_during_training.md)を参照してください。���の他の高度な設定については[こちらのドキュメント](./docs/advanced_config.md)を参照してください。
|
242 |
+
|
243 |
+
### LoRAの重みのマージ
|
244 |
+
|
245 |
+
```bash
|
246 |
+
python merge_lora.py \
|
247 |
+
--dit path/to/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \
|
248 |
+
--lora_weight path/to/lora.safetensors \
|
249 |
+
--save_merged_model path/to/merged_model.safetensors \
|
250 |
+
--device cpu \
|
251 |
+
--lora_multiplier 1.0
|
252 |
+
```
|
253 |
+
|
254 |
+
`--device`には計算を行うデバイス(`cpu`または`cuda`等)を指定してください。`cuda`を指定すると計算が高速化されます。
|
255 |
+
|
256 |
+
`--lora_weight`にはマージするLoRAの重みを、`--lora_multiplier`にはLoRAの重みの係数を、それぞれ指定してください。複数個が指定可能で、両者の数は一致させてください。
|
257 |
+
|
258 |
+
### 推論
|
259 |
+
|
260 |
+
以下のコマンドを使用して動画を生成します。
|
261 |
+
|
262 |
+
```bash
|
263 |
+
python hv_generate_video.py --fp8 --video_size 544 960 --video_length 5 --infer_steps 30
|
264 |
+
--prompt "A cat walks on the grass, realistic style." --save_path path/to/save/dir --output_type both
|
265 |
+
--dit path/to/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt --attn_mode sdpa --split_attn
|
266 |
+
--vae path/to/ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt
|
267 |
+
--vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128
|
268 |
+
--text_encoder1 path/to/ckpts/text_encoder
|
269 |
+
--text_encoder2 path/to/ckpts/text_encoder_2
|
270 |
+
--seed 1234 --lora_multiplier 1.0 --lora_weight path/to/lora.safetensors
|
271 |
+
```
|
272 |
+
|
273 |
+
その他のオプションは`python hv_generate_video.py --help`で確認できます。
|
274 |
+
|
275 |
+
`--fp8`を指定すると、DiTがfp8で推論されます。fp8は大きく消費メモリを削減できますが、品質は低下する可能性があります。
|
276 |
+
|
277 |
+
VRAMが足りない場合は、`--blocks_to_swap`を指定して、一部のブロックをCPUにオフロードしてください。最大38が指定できます。
|
278 |
+
|
279 |
+
`--attn_mode`には`flash`、`torch`、`sageattn`、`xformers`または`sdpa`(`torch`指定時と同じ)のいずれかを指定してください。それぞれFlashAttention、scaled dot product attention、SageAttention、xformersに対応します。デフォルトは`torch`です。SageAttentionはVRAMの削減に有効です。
|
280 |
+
|
281 |
+
`--split_attn`を指定すると、attentionを分割して処理します。SageAttention利用時で10%程度の高速化が見込まれます。
|
282 |
+
|
283 |
+
`--output_type`には`both`、`latent`、`video`、`images`のいずれかを指定してください。`both`はlatentと動画の両方を出力します。VAEでOut of Memoryエラーが発生する場合に備えて、`both`を指定することをお勧めします。`--latent_path`に保存されたlatentを指定し、`--output_type video` (または`images`)としてスクリプトを実行すると、VAEのdecodeのみを行えます。
|
284 |
+
|
285 |
+
`--seed`は省略可能です。指定しない場合はランダムなシードが使用されます。
|
286 |
+
|
287 |
+
`--video_length`は「4の倍数+1」を指定してください。
|
288 |
+
|
289 |
+
`--flow_shift`にタイムステップのシフト値(discrete flow shift)を指定可能です。省略時のデフォルト値は7.0で、これは推論ステップ数が50の時の推奨値です。HunyuanVideoの論文では、ステップ数50の場合は7.0、ステップ数20未満(10など)で17.0が推奨されています。
|
290 |
+
|
291 |
+
`--video_path`に読み込む動画を指定すると、video2videoの推論が可能です。動画ファイルを指定するか、複数の画像ファイルが入ったディレクトリを指定してください(画像ファイルはファイル名でソートされ、各フレームとして用いられます)。`--video_length`よりも短い動画を指定するとエラーになります。`--strength`で強度を指定できます。0~1.0で指定でき、大きいほど元の動画からの変化が大きくなります。
|
292 |
+
|
293 |
+
なおvideo2video推論の処理は実験的なものです。
|
294 |
+
|
295 |
+
`--save_merged_model`オプションで、LoRAマージ後のDiTモデルを保存できます。`--save_merged_model path/to/merged_model.safetensors`のように指定してください。なおこのオプションを指定すると推論は行われません。
|
296 |
+
|
297 |
+
### LoRAの形式の変換
|
298 |
+
|
299 |
+
ComfyUIで使用可能な形式(Diffusion-pipeと思われる)への変換は以下のコマンドで行えます。
|
300 |
+
|
301 |
+
```bash
|
302 |
+
python convert_lora.py --input path/to/musubi_lora.safetensors --output path/to/another_format.safetensors --target other
|
303 |
+
```
|
304 |
+
|
305 |
+
`--input`と`--output`はそれぞれ入力と出力のファイルパスを指定してください。
|
306 |
+
|
307 |
+
`--target`には`other`を指定してください。`default`を指定すると、他の形式から当リポジトリの形式に変換できます。
|
308 |
+
|
309 |
+
## その他
|
310 |
+
|
311 |
+
### SageAttentionのインストール方法
|
312 |
+
|
313 |
+
sdbds氏によるWindows対応のSageAttentionのwheelが https://github.com/sdbds/SageAttention-for-windows で公開されています��triton をインストールし、Python、PyTorch、CUDAのバージョンが一致する場合は、[Releases](https://github.com/sdbds/SageAttention-for-windows/releases)からビルド済みwheelをダウンロードしてインストールすることが可能です。sdbds氏に感謝します。
|
314 |
+
|
315 |
+
参考までに、以下は、SageAttentionをビルドしインストールするための簡単な手順です。Microsoft Visual C++ 再頒布可能パッケージを最新にする必要があるかもしれません。
|
316 |
+
|
317 |
+
1. Pythonのバージョンに応じたtriton 3.1.0のwhellを[こちら](https://github.com/woct0rdho/triton-windows/releases/tag/v3.1.0-windows.post5)からダウンロードしてインストールします。
|
318 |
+
|
319 |
+
2. Microsoft Visual Studio 2022かBuild Tools for Visual Studio 2022を、C++のビルドができるよう設定し、インストールします。(上のRedditの投稿を参照してください)。
|
320 |
+
|
321 |
+
3. 任意のフォルダにSageAttentionのリポジトリをクローンします。
|
322 |
+
```shell
|
323 |
+
git clone https://github.com/thu-ml/SageAttention.git
|
324 |
+
```
|
325 |
+
|
326 |
+
なお `git clone https://github.com/sdbds/SageAttention-for-windows.git` で、前述のsdbds氏のリポジトリを使用することで、手順4.を省略できます。
|
327 |
+
|
328 |
+
4. `SageAttention/csrc`フォルダ内の`math.cuh`を開き、71行目と146行目の `ushort` を `unsigned short` に変更して保存します。
|
329 |
+
|
330 |
+
5. スタートメニューから Visual Studio 2022 内の `x64 Native Tools Command Prompt for VS 2022` を選択してコマンドプロンプトを開きます。
|
331 |
+
|
332 |
+
6. venvを有効にし、SageAttentionのフォルダに移動して以下のコマンドを実行します。DISTUTILSが設定されていない、のようなエラーが出た場合は `set DISTUTILS_USE_SDK=1`としてから再度実行してください。
|
333 |
+
```shell
|
334 |
+
python setup.py install
|
335 |
+
```
|
336 |
+
|
337 |
+
以上でSageAttentionのインストールが完了です。
|
338 |
+
|
339 |
+
### PyTorchのバージョンについて
|
340 |
+
|
341 |
+
`--attn_mode`に`torch`を指定する場合、2.5.1以降のPyTorchを使用してください(それより前のバージョンでは生成される動画が真っ黒になるようです)。
|
342 |
+
|
343 |
+
古いバージョンを使う場合、xformersやSageAttentionを使用してください。
|
344 |
+
|
345 |
+
## 免責事項
|
346 |
+
|
347 |
+
このリポジトリは非公式であり、公式のHunyuanVideoリポジトリとは関係ありません。また、このリポジトリは開発中で、実験的なものです。テストおよびフィードバックを歓迎しますが、以下の点にご注意ください:
|
348 |
+
|
349 |
+
- 実際の稼働環境での動作を意図したものではありません
|
350 |
+
- 機能やAPIは予告なく変更されることがあります
|
351 |
+
- いくつもの機能が未検証です
|
352 |
+
- 動画学習機能はまだ開発中です
|
353 |
+
|
354 |
+
問題やバグについては、以下の情報とともにIssueを作成してください:
|
355 |
+
|
356 |
+
- 問題の詳細な説明
|
357 |
+
- 再現手順
|
358 |
+
- 環境の詳細(OS、GPU、VRAM、Pythonバージョンなど)
|
359 |
+
- 関連するエラーメッセージやログ
|
360 |
+
|
361 |
+
## コントリビューションについて
|
362 |
+
|
363 |
+
コントリビューションを歓迎します。ただし、以下にご注意ください:
|
364 |
+
|
365 |
+
- メンテナーのリソースが限られているため、PRのレビューやマージには時間がかかる場合があります
|
366 |
+
- 大きな変更に取り組む前には、議論のためのIssueを作成してください
|
367 |
+
- PRに関して:
|
368 |
+
- 変更は焦点を絞り、適度なサイズにしてください
|
369 |
+
- 明確な説明をお願いします
|
370 |
+
- 既存のコードスタイルに従ってください
|
371 |
+
- ドキュメントが更新されていることを確認してください
|
372 |
+
|
373 |
+
## ライセンス
|
374 |
+
|
375 |
+
`hunyuan_model`ディレクトリ以下のコードは、[HunyuanVideo](https://github.com/Tencent/HunyuanVideo)のコードを一部改変して使用しているため、そちらのライセンスに従います。
|
376 |
+
|
377 |
+
他のコードはApache License 2.0に従います。一部Diffusersのコードをコピー、改変して使用しています。
|
README.md
CHANGED
@@ -1,12 +1,269 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+

|
2 |
+
|
3 |
+
# Recent update
|
4 |
+
5/25/2025
|
5 |
+
Enable full intermediate previews for framepack tab, some change to framepack extension with image input logic.
|
6 |
+
5/24/2025
|
7 |
+
Batch images from folder now available for framepack. Save only extension option and experimental start image for f1 in framepack extension tab.
|
8 |
+
5/23/2025
|
9 |
+
Added ability to use the normal framepack model with endframe support in the framepack extension tab. Support additional bucket sizes.
|
10 |
+
5/18/2025
|
11 |
+
Add video extension tab. Currently only works with f1 model. Full credit to @pfqt and @chaojie for their amazing work!
|
12 |
+
|
13 |
+
# H1111
|
14 |
+
|
15 |
+
This is a GUI for tech wizard kohya-ss's musubi tuner's inference script.
|
16 |
+
https://github.com/kohya-ss/musubi-tuner
|
17 |
+
|
18 |
+
It allows inference with these models:
|
19 |
+
FramePack
|
20 |
+
Hunyuan-t2v
|
21 |
+
Hunyuan-i2v
|
22 |
+
Hunyuan-v2v
|
23 |
+
WanX-t2v
|
24 |
+
WanX-i2v
|
25 |
+
WanX-v2v
|
26 |
+
SkyReels-i2v
|
27 |
+
SkyReels-t2v
|
28 |
+
|
29 |
+
I have mostly been workiing on the framepack tab and the WanX-i2v tab. They are the best to use right now. WanX-i2v is used for skyreels v2 and the fun control models.
|
30 |
+
|
31 |
+
This supports queuing multiple different jobs if you open 2+ browser tabs and use the same model.
|
32 |
+
|
33 |
+
If you are running out of vram use more block swapping. Using FP8 scaled is also a decent option to lower memory usage, select fp8 and fp8 scaled to use it. Scaled fp8 tries to duplicate the important parts of the model from FP16. Sage attention is the fastest/lowest vram but difficult to install in windows.
|
34 |
+
|
35 |
+
Best quality will be obtained with only enabling block swapping and using the fp16 model with sdpa attention. You can speed things up with cfg skip, fp8 scaled, slg skip is small speedup, sage attention is fastest but all speedups come with quality degradations. I designed this to try to focus on quality over speed.
|
36 |
+
|
37 |
+
If you are using a lora that you didn't train with musubi you need to drag it to the convert lora tab and convert it to the default format. It should spit it out into the /lora folder.
|
38 |
+
|
39 |
+
If you need additional installation instructions or information create an issue and I will try to help. Also there are alot of settings notes on the musubi github linked above.
|
40 |
+
|
41 |
+
For torch 2.7.0 and windows installation try:
|
42 |
+
pip install typing-extensions
|
43 |
+
pip install torch==2.7.0+cu128 torchvision==0.22.0+cu128 --index-url https://download.pytorch.org/whl/cu128
|
44 |
+
pip install -r requirementsTorch27.txt
|
45 |
+
|
46 |
+
## To Use FramePack
|
47 |
+
|
48 |
+
|
49 |
+
download these 5 files from https://huggingface.co/maybleMyers/framepack_h1111 and put them in a subfolder named hunyuan (H1111/hunyuan), or reference where they are in the gui if you have already aquired them.
|
50 |
+
|
51 |
+
FramePackI2V_HY_bf16.safetensors or FramePack_F1_I2V_HY_20250503.safetensors for F1
|
52 |
+
|
53 |
+
clip_l.safetensors
|
54 |
+
|
55 |
+
llava_llama3_fp16.safetensors
|
56 |
+
|
57 |
+
model.safetensors
|
58 |
+
|
59 |
+
pytorch_model.pt
|
60 |
+
|
61 |
+
Lora trained with musubi tuner's framepack training confirmed to work great. Normal lora trained for hunyuan kinda suck. Use a lot of block swap this is a different back end than the official repo. If you select fp8 and fp8 scaled it will all fit on a 24gb gpu for fastest speed, about 3s/it or 1:17 per second of video w/ a 4090. Best quality will be obtained with just block swapping/sdpa attention/full model though.
|
62 |
+
|
63 |
+
Put loras in a /lora subfolder, if not trained with musubi you need to convert them.
|
64 |
+
|
65 |
+
Only unipc is supported for now. Sage attn is experimental. When using the F1 model not all options available for the original framepack model will work, like endframe and sectional images.
|
66 |
+
|
67 |
+
Here is an example prompt for a 5 second video with 4 sections using sectional prompting, also supports longer videos with indexes ie 0-2 ;;;3-5 etc:
|
68 |
+
|
69 |
+
0:A cinematic video showcases a cute blue penguin wearing sunglasses. The penguin runs quickly into mcdonalds.;;;1:The penguin runs quickly into mcdonalds and jumps up on a table and starts eating his food. The penguin's name is Piplup he is a famous Pokemon actor. The video is a fast action sequence animation showing the penguin running into a mcdonalds an jumping up onto a table.;;;2:The penguin is seated at a table and is enjoying his happy meal. The penguin's name is Piplup he is a famous Pokemon actor. The video is a fast action sequence animation showing the penguin running into a mcdonalds and jumping up onto a table.;;;3:The penguin is seated at a table and is happily enjoying his happy meal. The penguin's name is Piplup he is a famous Pokemon actor. The penguin flexes his huge arm muscles at the end of the video.
|
70 |
+
|
71 |
+
I have added support for 4 sectional images during inference. It works best when the images are close together. Refer to the screen shot for an example of a working 5 second video.
|
72 |
+
|
73 |
+
For more details on using framepack with musubi go here https://github.com/kohya-ss/musubi-tuner/blob/main/docs/framepack.md
|
74 |
+
|
75 |
+
Fastest speed will be achieved with fp8 and fp8 scaled, then you can reduce block swapping to your memory constraints. (leave about 1gb free)
|
76 |
+
|
77 |
+
Framepack Extension tab is still a work in progress.
|
78 |
+
Thanks to @pftq https://github.com/pftq and @chaojie https://github.com/chaojie for their work on the extension logics.
|
79 |
+
|
80 |
+
## To Use the new Skyreels-V2 models
|
81 |
+
|
82 |
+
I have provided these 2 at https://huggingface.co/maybleMyers/wan_files_for_h1111
|
83 |
+
|
84 |
+
SkyReels-V2-I2V-14B-720P-FP16.safetensors
|
85 |
+
SkyReels-V2-I2V-14B-540P-FP16.safetensors
|
86 |
+
|
87 |
+
You can just drop them into the wan folder and use them in the WanX-i2v tab. Skyreels-V2 is a fine tune from Wan2.1.
|
88 |
+
If you have download the kijai variants the will not work because he added extra keys to the model.
|
89 |
+
|
90 |
+
## To Use WanX
|
91 |
+
|
92 |
+
To use wanX download these and toss them in the wan subfolder:
|
93 |
+
Download the T5 `models_t5_umt5-xxl-enc-bf16.pth`, vae `Wan2.1_VAE.pth` and CLIP `models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` from the following page: https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P/tree/main
|
94 |
+
|
95 |
+
Download the DiT weights from the following page: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
|
96 |
+
ie : wan2.1_i2v_720p_14B_fp16.safetensors
|
97 |
+
|
98 |
+
For the fun control option in WanX-i2v I recommend the fp16 weights here: https://huggingface.co/maybleMyers/wan_files_for_h1111/tree/main
|
99 |
+
Wan2.1-Fun-14B-Control_fp16.safetensors
|
100 |
+
|
101 |
+
git pull to update the installation
|
102 |
+
pip install -r requirements.txt
|
103 |
+
|
104 |
+
I have tested the 14B i2v and t2v models so far to be working
|
105 |
+
|
106 |
+
## Requirements
|
107 |
+
|
108 |
+
- Python 3.10
|
109 |
+
- CUDA 12.4
|
110 |
+
|
111 |
+
## Basic Installation (Linux)
|
112 |
+
|
113 |
+
Tested on ubuntu 24
|
114 |
+
|
115 |
+
to update navigate to H1111 and git pull
|
116 |
+
|
117 |
+
```powershell
|
118 |
+
git clone https://github.com/maybleMyers/H1111
|
119 |
+
cd H1111
|
120 |
+
python -m venv env
|
121 |
+
#(if you have another version of python do python3.10 -m venv env after you install it with sudo apt install python3.10 python3.10-venv python3.10-distutils)
|
122 |
+
source env/bin/activate
|
123 |
+
pip install torch==2.5.1 torchvision --index-url https://download.pytorch.org/whl/cu124
|
124 |
+
pip install -r requirements.txt
|
125 |
+
pip install flash-attn --no-build-isolation
|
126 |
+
pip install sageattention==1.0.6
|
127 |
+
might need python3.10-dev as well for sage attention to work
|
128 |
+
|
129 |
+
```
|
130 |
+
|
131 |
+
run with
|
132 |
+
source env/bin/activate
|
133 |
+
python h1111.py
|
134 |
+
|
135 |
+
for GPU1
|
136 |
+
CUDA_VISIBLE_DEVICES=1 python h1111.py
|
137 |
+
|
138 |
+
## Basic Installation (Windows)
|
139 |
+
|
140 |
+
|
141 |
+
|
142 |
+
First, open PowerShell and navigate to your desired installation directory. Then run these commands:
|
143 |
+
|
144 |
+
```powershell
|
145 |
+
git clone https://github.com/maybleMyers/H1111
|
146 |
+
cd H1111
|
147 |
+
python -m venv env
|
148 |
+
./env/scripts/activate
|
149 |
+
pip install torch==2.5.1 torchvision --index-url https://download.pytorch.org/whl/cu124
|
150 |
+
pip install -r requirements.txt
|
151 |
+
|
152 |
+
```
|
153 |
+
|
154 |
+
## To run
|
155 |
+
|
156 |
+
```
|
157 |
+
env/scripts/activate
|
158 |
+
python h1111.py
|
159 |
+
```
|
160 |
+
|
161 |
+
open 127.0.0.1:7860 in a browser
|
162 |
+
|
163 |
+
You can set cuda device to 1,2,3,4,5,6,7 etc in the env once activated in a separate terminal to run unlimited copies at once if you have another gpu.
|
164 |
+
ie for linux on the second gpu: CUDA_VISIBLE_DEVICES=1 python h1111.py
|
165 |
+
|
166 |
+
## full changlog
|
167 |
+
5/25/2025
|
168 |
+
Enable full intermediate previews for framepack tab, some change to framepack extension with image input logic.
|
169 |
+
5/24/2025
|
170 |
+
Batch images from folder now available for framepack. Save only extension option and experimental start image for f1 in framepack extension tab.
|
171 |
+
5/23/2025
|
172 |
+
Added ability to use the normal framepack model with endframe support in the framepack extension tab. Support additional bucket sizes.
|
173 |
+
5/18/2025
|
174 |
+
Add video extension tab. Currently only works with f1 model. Full credit to @pfqt and @chaojie for their amazing work!
|
175 |
+
5/12/2025
|
176 |
+
Add skip button to framepack.
|
177 |
+
5/9/2025
|
178 |
+
Add testing branch for framepack F1 end image, kinda glitchygo https://github.com/maybleMyers/H1111/tree/f1_end
|
179 |
+
5/5/2025
|
180 |
+
Update an experimental hunyuan to framepack convert lora option in the convert lora tab.
|
181 |
+
Add tea cache to frame pack.
|
182 |
+
5/3/2025
|
183 |
+
Add support for framepack F1! download from https://huggingface.co/maybleMyers/wan_files_for_h1111/blob/main/FramePack_F1_I2V_HY_20250503.safetensors put it in your hunyuan folder. You might need to reinstall reqs "pip install -r requirements.txt"
|
184 |
+
Add support for Wan2.1 i2v-14B-FC-1.1. It is a fun control model and is very good. Use it in the WanX-i2v tab and make sure to select the task i2v-14B-FC-1.1 at the bottom of the page. Download the weights from https://huggingface.co/maybleMyers/wan_files_for_h1111
|
185 |
+
4/30/2025
|
186 |
+
Previews for framepack.
|
187 |
+
4/29/2025
|
188 |
+
Add initial preview support to the wanX-i2v tab based. If you want to use them use the preview branch. Thanks to Sarania.
|
189 |
+
Wan2.1-Fun-V1.1-14B-InP-FP16.safetensors is available at https://huggingface.co/maybleMyers/wan_files_for_h1111
|
190 |
+
Fix bug in hunyuan-t2v not loading lora.
|
191 |
+
4/26/2025
|
192 |
+
Add SkyReels-V2-I2V-14B-720P-FP16.safetensors to supported models.
|
193 |
+
Added alot better options for Framepack including working sectional images, Thanks to kohya!
|
194 |
+
4/25/2025
|
195 |
+
Framepack backend updates for better LoRa support for LoRa's trained with musubi tuner. Also better weighting options.
|
196 |
+
4/24/2025
|
197 |
+
Update FramePack backend to musubi backend instead of original. Offers much improved speed and some quality improvements.
|
198 |
+
Add support for torch 2.7.0 + cuda 12.8
|
199 |
+
4/18/2025
|
200 |
+
Add initial support for FramePack. https://github.com/lllyasviel/FramePack
|
201 |
+
4/15/2025
|
202 |
+
Add much improved functionality for the wan fun control model. Added strength imrpovements and dropoff code to choose when to apply the control video. Thanks wordbrew.
|
203 |
+
4/3/2025
|
204 |
+
Add support for hunyuan i2v model. Download the clip vision from https://huggingface.co/maybleMyers/H1111_Hunyuan_i2v And download the official model from hunyuan's website and rename it to mp_rank_00_model_states_i2v.pt https://huggingface.co/tencent/HunyuanVideo-I2V/tree/main/hunyuan-video-i2v-720p/transformers add both to your hunyuan folder.
|
205 |
+
3/29/2025
|
206 |
+
Added support for fun models! download dit from https://huggingface.co/alibaba-pai/Wan2.1-Fun-14B-Control and specify correct task type and dit location. I renamed it from diffusion_pytorch_model to Wan2.1-Fun-14B-control. Works in the normal WanX-i2v tab when you select the control option at the bottom of the page.
|
207 |
+
3/23/2025
|
208 |
+
Added Wanx cfg skip functionality to skip cfg guidance during inference for faster generations but less following of the prompt
|
209 |
+
3/22/2025
|
210 |
+
Added WanX-i2v end frame functionality
|
211 |
+
3/20/2025
|
212 |
+
Added WanX-v2v functionality.
|
213 |
+
3/18/2025
|
214 |
+
Added Skip Layer Guidance for WanX-i2v.
|
215 |
+
3/13/2025
|
216 |
+
Added extend video functionality to WanX-i2v. It kind of works .
|
217 |
+
3/12/2025
|
218 |
+
Added ability to send the last frame of a video to the input in WanX-i2v. Also you can now use this to extend the video. You can do multiple batches at each step and pick the best extended video then generate an even longer one.
|
219 |
+
3/9/2025
|
220 |
+
Added batching ability for a folder full of images in WanX-i2v tab. Added flash attn for windows prebuilt wheel.
|
221 |
+
3/8/2025
|
222 |
+
Added support for wan lora's. Remember to convert them first in the convert lora tab.
|
223 |
+
3/5/2025
|
224 |
+
Added ability to batch a folder of images with skyreels i2v, so you can make a video with every image in a folder.
|
225 |
+
3/2/2025
|
226 |
+
Added initial support for wanX-2.1 Image to Video and Text to Video inference.
|
227 |
+
3/1/2025
|
228 |
+
Added support for Skyreels Video to Video and Text to Video.
|
229 |
+
2/23/2025
|
230 |
+
Added initial support for skyreels-V1 using musubi's skyreel implementation. (thanks sdbds)
|
231 |
+
download models from https://huggingface.co/Kijai/SkyReels-V1-Hunyuan_comfy and add them to your hunyuan folder
|
232 |
+
skyreels_hunyuan_i2v_bf16.safetensors
|
233 |
+
skyreels_hunyuan_t2v_bf16.safetensors
|
234 |
+
|
235 |
+
|
236 |
+
## to use stock hunyuan models
|
237 |
+
|
238 |
+
https://huggingface.co/tencent/HunyuanVideo/resolve/main/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt
|
239 |
+
|
240 |
+
https://huggingface.co/tencent/HunyuanVideo/resolve/main/hunyuan-video-t2v-720p/vae/pytorch_model.pt
|
241 |
+
|
242 |
+
https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged/resolve/main/split_files/text_encoders/llava_llama3_fp16.safetensors
|
243 |
+
|
244 |
+
https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged/resolve/main/split_files/text_encoders/clip_l.safetensors
|
245 |
+
|
246 |
+
#fp8 dit model
|
247 |
+
|
248 |
+
https://huggingface.co/kohya-ss/HunyuanVideo-fp8_e4m3fn-unofficial/resolve/main/mp_rank_00_model_states_fp8.safetensors
|
249 |
+
|
250 |
+
place models in H1111/hunyuan folder
|
251 |
+
|
252 |
+
### Optional: Install Xformers
|
253 |
+
```powershell
|
254 |
+
pip install --no-deps xformers --index-url https://download.pytorch.org/whl/cu124
|
255 |
+
```
|
256 |
+
|
257 |
+
### Optional: Install Flash Attention
|
258 |
+
Note: This can take 1-5 hour to install even on a good CPU, but provides faster generation.
|
259 |
+
I have uploaded a wheel for windows users to match cuda 12.4 and python 3.10.(thanks lldacing)
|
260 |
+
https://huggingface.co/maybleMyers/wan_files_for_h1111/resolve/main/flash_attn-2.7.4%2Bcu124torch2.5.1cxx11abiFALSE-cp310-cp310-win_amd64.whl?download=true
|
261 |
+
|
262 |
+
```powershell
|
263 |
+
pip install flash-attn --no-build-isolation
|
264 |
+
|
265 |
+
If you have downloaded the wheel you can install it with:
|
266 |
+
|
267 |
+
pip install "flash_attn-2.7.4+cu124torch2.5.1cxx11abiFALSE-cp310-cp310-win_amd64.whl"
|
268 |
+
```
|
269 |
+
```
|
base_hv_generate_video.py
ADDED
@@ -0,0 +1,936 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from datetime import datetime
|
3 |
+
from pathlib import Path
|
4 |
+
import random
|
5 |
+
import sys
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
from typing import Optional, Union
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torchvision
|
13 |
+
import accelerate
|
14 |
+
from diffusers.utils.torch_utils import randn_tensor
|
15 |
+
from transformers.models.llama import LlamaModel
|
16 |
+
from tqdm import tqdm
|
17 |
+
import av
|
18 |
+
from einops import rearrange
|
19 |
+
from safetensors.torch import load_file, save_file
|
20 |
+
from safetensors import safe_open
|
21 |
+
from PIL import Image
|
22 |
+
|
23 |
+
from hunyuan_model import vae
|
24 |
+
from hunyuan_model.text_encoder import TextEncoder
|
25 |
+
from hunyuan_model.text_encoder import PROMPT_TEMPLATE
|
26 |
+
from hunyuan_model.vae import load_vae
|
27 |
+
from hunyuan_model.models import load_transformer, get_rotary_pos_embed
|
28 |
+
from hunyuan_model.fp8_optimization import convert_fp8_linear
|
29 |
+
from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
|
30 |
+
from networks import lora
|
31 |
+
|
32 |
+
try:
|
33 |
+
from lycoris.kohya import create_network_from_weights
|
34 |
+
except:
|
35 |
+
pass
|
36 |
+
|
37 |
+
from utils.model_utils import str_to_dtype
|
38 |
+
from utils.safetensors_utils import mem_eff_save_file
|
39 |
+
from dataset.image_video_dataset import load_video, glob_images, resize_image_to_bucket
|
40 |
+
|
41 |
+
import logging
|
42 |
+
|
43 |
+
logger = logging.getLogger(__name__)
|
44 |
+
logging.basicConfig(level=logging.INFO)
|
45 |
+
|
46 |
+
|
47 |
+
def clean_memory_on_device(device):
|
48 |
+
if device.type == "cuda":
|
49 |
+
torch.cuda.empty_cache()
|
50 |
+
elif device.type == "cpu":
|
51 |
+
pass
|
52 |
+
elif device.type == "mps": # not tested
|
53 |
+
torch.mps.empty_cache()
|
54 |
+
|
55 |
+
|
56 |
+
def synchronize_device(device: torch.device):
|
57 |
+
if device.type == "cuda":
|
58 |
+
torch.cuda.synchronize()
|
59 |
+
elif device.type == "xpu":
|
60 |
+
torch.xpu.synchronize()
|
61 |
+
elif device.type == "mps":
|
62 |
+
torch.mps.synchronize()
|
63 |
+
|
64 |
+
|
65 |
+
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24):
|
66 |
+
"""save videos by video tensor
|
67 |
+
copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61
|
68 |
+
|
69 |
+
Args:
|
70 |
+
videos (torch.Tensor): video tensor predicted by the model
|
71 |
+
path (str): path to save video
|
72 |
+
rescale (bool, optional): rescale the video tensor from [-1, 1] to . Defaults to False.
|
73 |
+
n_rows (int, optional): Defaults to 1.
|
74 |
+
fps (int, optional): video save fps. Defaults to 8.
|
75 |
+
"""
|
76 |
+
videos = rearrange(videos, "b c t h w -> t b c h w")
|
77 |
+
outputs = []
|
78 |
+
for x in videos:
|
79 |
+
x = torchvision.utils.make_grid(x, nrow=n_rows)
|
80 |
+
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
81 |
+
if rescale:
|
82 |
+
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
|
83 |
+
x = torch.clamp(x, 0, 1)
|
84 |
+
x = (x * 255).numpy().astype(np.uint8)
|
85 |
+
outputs.append(x)
|
86 |
+
|
87 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
88 |
+
|
89 |
+
# # save video with av
|
90 |
+
# container = av.open(path, "w")
|
91 |
+
# stream = container.add_stream("libx264", rate=fps)
|
92 |
+
# for x in outputs:
|
93 |
+
# frame = av.VideoFrame.from_ndarray(x, format="rgb24")
|
94 |
+
# packet = stream.encode(frame)
|
95 |
+
# container.mux(packet)
|
96 |
+
# packet = stream.encode(None)
|
97 |
+
# container.mux(packet)
|
98 |
+
# container.close()
|
99 |
+
|
100 |
+
height, width, _ = outputs[0].shape
|
101 |
+
|
102 |
+
# create output container
|
103 |
+
container = av.open(path, mode="w")
|
104 |
+
|
105 |
+
# create video stream
|
106 |
+
codec = "libx264"
|
107 |
+
pixel_format = "yuv420p"
|
108 |
+
stream = container.add_stream(codec, rate=fps)
|
109 |
+
stream.width = width
|
110 |
+
stream.height = height
|
111 |
+
stream.pix_fmt = pixel_format
|
112 |
+
stream.bit_rate = 4000000 # 4Mbit/s
|
113 |
+
|
114 |
+
for frame_array in outputs:
|
115 |
+
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
|
116 |
+
packets = stream.encode(frame)
|
117 |
+
for packet in packets:
|
118 |
+
container.mux(packet)
|
119 |
+
|
120 |
+
for packet in stream.encode():
|
121 |
+
container.mux(packet)
|
122 |
+
|
123 |
+
container.close()
|
124 |
+
|
125 |
+
|
126 |
+
def save_images_grid(
|
127 |
+
videos: torch.Tensor, parent_dir: str, image_name: str, rescale: bool = False, n_rows: int = 1, create_subdir=True
|
128 |
+
):
|
129 |
+
videos = rearrange(videos, "b c t h w -> t b c h w")
|
130 |
+
outputs = []
|
131 |
+
for x in videos:
|
132 |
+
x = torchvision.utils.make_grid(x, nrow=n_rows)
|
133 |
+
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
134 |
+
if rescale:
|
135 |
+
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
|
136 |
+
x = torch.clamp(x, 0, 1)
|
137 |
+
x = (x * 255).numpy().astype(np.uint8)
|
138 |
+
outputs.append(x)
|
139 |
+
|
140 |
+
if create_subdir:
|
141 |
+
output_dir = os.path.join(parent_dir, image_name)
|
142 |
+
else:
|
143 |
+
output_dir = parent_dir
|
144 |
+
|
145 |
+
os.makedirs(output_dir, exist_ok=True)
|
146 |
+
for i, x in enumerate(outputs):
|
147 |
+
image_path = os.path.join(output_dir, f"{image_name}_{i:03d}.png")
|
148 |
+
image = Image.fromarray(x)
|
149 |
+
image.save(image_path)
|
150 |
+
|
151 |
+
|
152 |
+
# region Encoding prompt
|
153 |
+
|
154 |
+
|
155 |
+
def encode_prompt(prompt: Union[str, list[str]], device: torch.device, num_videos_per_prompt: int, text_encoder: TextEncoder):
|
156 |
+
r"""
|
157 |
+
Encodes the prompt into text encoder hidden states.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
prompt (`str` or `List[str]`):
|
161 |
+
prompt to be encoded
|
162 |
+
device: (`torch.device`):
|
163 |
+
torch device
|
164 |
+
num_videos_per_prompt (`int`):
|
165 |
+
number of videos that should be generated per prompt
|
166 |
+
text_encoder (TextEncoder):
|
167 |
+
text encoder to be used for encoding the prompt
|
168 |
+
"""
|
169 |
+
# LoRA and Textual Inversion are not supported in this script
|
170 |
+
# negative prompt and prompt embedding are not supported in this script
|
171 |
+
# clip_skip is not supported in this script because it is not used in the original script
|
172 |
+
data_type = "video" # video only, image is not supported
|
173 |
+
|
174 |
+
text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
|
175 |
+
|
176 |
+
with torch.no_grad():
|
177 |
+
prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type, device=device)
|
178 |
+
prompt_embeds = prompt_outputs.hidden_state
|
179 |
+
|
180 |
+
attention_mask = prompt_outputs.attention_mask
|
181 |
+
if attention_mask is not None:
|
182 |
+
attention_mask = attention_mask.to(device)
|
183 |
+
bs_embed, seq_len = attention_mask.shape
|
184 |
+
attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
|
185 |
+
attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len)
|
186 |
+
|
187 |
+
prompt_embeds_dtype = text_encoder.dtype
|
188 |
+
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
189 |
+
|
190 |
+
if prompt_embeds.ndim == 2:
|
191 |
+
bs_embed, _ = prompt_embeds.shape
|
192 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
193 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
|
194 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
|
195 |
+
else:
|
196 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
197 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
198 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
199 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
|
200 |
+
|
201 |
+
return prompt_embeds, attention_mask
|
202 |
+
|
203 |
+
|
204 |
+
def encode_input_prompt(prompt: Union[str, list[str]], args, device, fp8_llm=False, accelerator=None):
|
205 |
+
# constants
|
206 |
+
prompt_template_video = "dit-llm-encode-video"
|
207 |
+
prompt_template = "dit-llm-encode"
|
208 |
+
text_encoder_dtype = torch.float16
|
209 |
+
text_encoder_type = "llm"
|
210 |
+
text_len = 256
|
211 |
+
hidden_state_skip_layer = 2
|
212 |
+
apply_final_norm = False
|
213 |
+
reproduce = False
|
214 |
+
|
215 |
+
text_encoder_2_type = "clipL"
|
216 |
+
text_len_2 = 77
|
217 |
+
|
218 |
+
num_videos = 1
|
219 |
+
|
220 |
+
# if args.prompt_template_video is not None:
|
221 |
+
# crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
|
222 |
+
# elif args.prompt_template is not None:
|
223 |
+
# crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
|
224 |
+
# else:
|
225 |
+
# crop_start = 0
|
226 |
+
crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0)
|
227 |
+
max_length = text_len + crop_start
|
228 |
+
|
229 |
+
# prompt_template
|
230 |
+
prompt_template = PROMPT_TEMPLATE[prompt_template]
|
231 |
+
|
232 |
+
# prompt_template_video
|
233 |
+
prompt_template_video = PROMPT_TEMPLATE[prompt_template_video] # if args.prompt_template_video is not None else None
|
234 |
+
|
235 |
+
# load text encoders
|
236 |
+
logger.info(f"loading text encoder: {args.text_encoder1}")
|
237 |
+
text_encoder = TextEncoder(
|
238 |
+
text_encoder_type=text_encoder_type,
|
239 |
+
max_length=max_length,
|
240 |
+
text_encoder_dtype=text_encoder_dtype,
|
241 |
+
text_encoder_path=args.text_encoder1,
|
242 |
+
tokenizer_type=text_encoder_type,
|
243 |
+
prompt_template=prompt_template,
|
244 |
+
prompt_template_video=prompt_template_video,
|
245 |
+
hidden_state_skip_layer=hidden_state_skip_layer,
|
246 |
+
apply_final_norm=apply_final_norm,
|
247 |
+
reproduce=reproduce,
|
248 |
+
)
|
249 |
+
text_encoder.eval()
|
250 |
+
if fp8_llm:
|
251 |
+
org_dtype = text_encoder.dtype
|
252 |
+
logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
|
253 |
+
text_encoder.to(device=device, dtype=torch.float8_e4m3fn)
|
254 |
+
|
255 |
+
# prepare LLM for fp8
|
256 |
+
def prepare_fp8(llama_model: LlamaModel, target_dtype):
|
257 |
+
def forward_hook(module):
|
258 |
+
def forward(hidden_states):
|
259 |
+
input_dtype = hidden_states.dtype
|
260 |
+
hidden_states = hidden_states.to(torch.float32)
|
261 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
262 |
+
hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
|
263 |
+
return module.weight.to(input_dtype) * hidden_states.to(input_dtype)
|
264 |
+
|
265 |
+
return forward
|
266 |
+
|
267 |
+
for module in llama_model.modules():
|
268 |
+
if module.__class__.__name__ in ["Embedding"]:
|
269 |
+
# print("set", module.__class__.__name__, "to", target_dtype)
|
270 |
+
module.to(target_dtype)
|
271 |
+
if module.__class__.__name__ in ["LlamaRMSNorm"]:
|
272 |
+
# print("set", module.__class__.__name__, "hooks")
|
273 |
+
module.forward = forward_hook(module)
|
274 |
+
|
275 |
+
prepare_fp8(text_encoder.model, org_dtype)
|
276 |
+
|
277 |
+
logger.info(f"loading text encoder 2: {args.text_encoder2}")
|
278 |
+
text_encoder_2 = TextEncoder(
|
279 |
+
text_encoder_type=text_encoder_2_type,
|
280 |
+
max_length=text_len_2,
|
281 |
+
text_encoder_dtype=text_encoder_dtype,
|
282 |
+
text_encoder_path=args.text_encoder2,
|
283 |
+
tokenizer_type=text_encoder_2_type,
|
284 |
+
reproduce=reproduce,
|
285 |
+
)
|
286 |
+
text_encoder_2.eval()
|
287 |
+
|
288 |
+
# encode prompt
|
289 |
+
logger.info(f"Encoding prompt with text encoder 1")
|
290 |
+
text_encoder.to(device=device)
|
291 |
+
if fp8_llm:
|
292 |
+
with accelerator.autocast():
|
293 |
+
prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)
|
294 |
+
else:
|
295 |
+
prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)
|
296 |
+
text_encoder = None
|
297 |
+
clean_memory_on_device(device)
|
298 |
+
|
299 |
+
logger.info(f"Encoding prompt with text encoder 2")
|
300 |
+
text_encoder_2.to(device=device)
|
301 |
+
prompt_embeds_2, prompt_mask_2 = encode_prompt(prompt, device, num_videos, text_encoder_2)
|
302 |
+
|
303 |
+
prompt_embeds = prompt_embeds.to("cpu")
|
304 |
+
prompt_mask = prompt_mask.to("cpu")
|
305 |
+
prompt_embeds_2 = prompt_embeds_2.to("cpu")
|
306 |
+
prompt_mask_2 = prompt_mask_2.to("cpu")
|
307 |
+
|
308 |
+
text_encoder_2 = None
|
309 |
+
clean_memory_on_device(device)
|
310 |
+
|
311 |
+
return prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2
|
312 |
+
|
313 |
+
|
314 |
+
# endregion
|
315 |
+
|
316 |
+
|
317 |
+
def prepare_vae(args, device):
|
318 |
+
vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
|
319 |
+
vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
|
320 |
+
vae.eval()
|
321 |
+
# vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
|
322 |
+
|
323 |
+
# set chunk_size to CausalConv3d recursively
|
324 |
+
chunk_size = args.vae_chunk_size
|
325 |
+
if chunk_size is not None:
|
326 |
+
vae.set_chunk_size_for_causal_conv_3d(chunk_size)
|
327 |
+
logger.info(f"Set chunk_size to {chunk_size} for CausalConv3d")
|
328 |
+
|
329 |
+
if args.vae_spatial_tile_sample_min_size is not None:
|
330 |
+
vae.enable_spatial_tiling(True)
|
331 |
+
vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
|
332 |
+
vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
|
333 |
+
# elif args.vae_tiling:
|
334 |
+
else:
|
335 |
+
vae.enable_spatial_tiling(True)
|
336 |
+
|
337 |
+
return vae, vae_dtype
|
338 |
+
|
339 |
+
|
340 |
+
def encode_to_latents(args, video, device):
|
341 |
+
vae, vae_dtype = prepare_vae(args, device)
|
342 |
+
|
343 |
+
video = video.to(device=device, dtype=vae_dtype)
|
344 |
+
video = video * 2 - 1 # 0, 1 -> -1, 1
|
345 |
+
with torch.no_grad():
|
346 |
+
latents = vae.encode(video).latent_dist.sample()
|
347 |
+
|
348 |
+
if hasattr(vae.config, "shift_factor") and vae.config.shift_factor:
|
349 |
+
latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
|
350 |
+
else:
|
351 |
+
latents = latents * vae.config.scaling_factor
|
352 |
+
|
353 |
+
return latents
|
354 |
+
|
355 |
+
|
356 |
+
def decode_latents(args, latents, device):
|
357 |
+
vae, vae_dtype = prepare_vae(args, device)
|
358 |
+
|
359 |
+
expand_temporal_dim = False
|
360 |
+
if len(latents.shape) == 4:
|
361 |
+
latents = latents.unsqueeze(2)
|
362 |
+
expand_temporal_dim = True
|
363 |
+
elif len(latents.shape) == 5:
|
364 |
+
pass
|
365 |
+
else:
|
366 |
+
raise ValueError(f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.")
|
367 |
+
|
368 |
+
if hasattr(vae.config, "shift_factor") and vae.config.shift_factor:
|
369 |
+
latents = latents / vae.config.scaling_factor + vae.config.shift_factor
|
370 |
+
else:
|
371 |
+
latents = latents / vae.config.scaling_factor
|
372 |
+
|
373 |
+
latents = latents.to(device=device, dtype=vae_dtype)
|
374 |
+
with torch.no_grad():
|
375 |
+
image = vae.decode(latents, return_dict=False)[0]
|
376 |
+
|
377 |
+
if expand_temporal_dim:
|
378 |
+
image = image.squeeze(2)
|
379 |
+
|
380 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
381 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
382 |
+
image = image.cpu().float()
|
383 |
+
|
384 |
+
return image
|
385 |
+
|
386 |
+
|
387 |
+
def parse_args():
|
388 |
+
parser = argparse.ArgumentParser(description="HunyuanVideo inference script")
|
389 |
+
|
390 |
+
parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path or directory")
|
391 |
+
parser.add_argument(
|
392 |
+
"--dit_in_channels",
|
393 |
+
type=int,
|
394 |
+
default=None,
|
395 |
+
help="input channels for DiT, default is None (automatically detect). 32 for SkyReels-I2V, 16 for others",
|
396 |
+
)
|
397 |
+
parser.add_argument("--vae", type=str, required=True, help="VAE checkpoint path or directory")
|
398 |
+
parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
|
399 |
+
parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
|
400 |
+
parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
|
401 |
+
|
402 |
+
# LoRA
|
403 |
+
parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
|
404 |
+
parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
|
405 |
+
parser.add_argument(
|
406 |
+
"--save_merged_model",
|
407 |
+
type=str,
|
408 |
+
default=None,
|
409 |
+
help="Save merged model to path. If specified, no inference will be performed.",
|
410 |
+
)
|
411 |
+
parser.add_argument("--exclude_single_blocks", action="store_true", help="Exclude single blocks when loading LoRA weights")
|
412 |
+
|
413 |
+
# inference
|
414 |
+
parser.add_argument("--prompt", type=str, required=True, help="prompt for generation")
|
415 |
+
parser.add_argument("--negative_prompt", type=str, default=None, help="negative prompt for generation")
|
416 |
+
parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size")
|
417 |
+
parser.add_argument("--video_length", type=int, default=129, help="video length")
|
418 |
+
parser.add_argument("--fps", type=int, default=24, help="video fps")
|
419 |
+
parser.add_argument("--infer_steps", type=int, default=50, help="number of inference steps")
|
420 |
+
parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
|
421 |
+
parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
|
422 |
+
parser.add_argument(
|
423 |
+
"--guidance_scale",
|
424 |
+
type=float,
|
425 |
+
default=1.0,
|
426 |
+
help="Guidance scale for classifier free guidance. Default is 1.0 (means no guidance)",
|
427 |
+
)
|
428 |
+
parser.add_argument("--embedded_cfg_scale", type=float, default=6.0, help="Embeded classifier free guidance scale.")
|
429 |
+
parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference")
|
430 |
+
parser.add_argument(
|
431 |
+
"--image_path", type=str, default=None, help="path to image for image2video inference, only works for SkyReels-I2V model"
|
432 |
+
)
|
433 |
+
parser.add_argument(
|
434 |
+
"--split_uncond",
|
435 |
+
action="store_true",
|
436 |
+
help="split unconditional call for classifier free guidance, slower but less memory usage",
|
437 |
+
)
|
438 |
+
parser.add_argument("--strength", type=float, default=0.8, help="strength for video2video inference")
|
439 |
+
|
440 |
+
# Flow Matching
|
441 |
+
parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers.")
|
442 |
+
|
443 |
+
parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
|
444 |
+
parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
|
445 |
+
parser.add_argument(
|
446 |
+
"--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
|
447 |
+
)
|
448 |
+
parser.add_argument(
|
449 |
+
"--attn_mode", type=str, default="torch", choices=["flash", "torch", "sageattn", "xformers", "sdpa"], help="attention mode"
|
450 |
+
)
|
451 |
+
parser.add_argument(
|
452 |
+
"--split_attn", action="store_true", help="use split attention, default is False. if True, --split_uncond becomes True"
|
453 |
+
)
|
454 |
+
parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
|
455 |
+
parser.add_argument(
|
456 |
+
"--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
|
457 |
+
)
|
458 |
+
parser.add_argument("--blocks_to_swap", type=int, default=None, help="number of blocks to swap in the model")
|
459 |
+
parser.add_argument("--img_in_txt_in_offloading", action="store_true", help="offload img_in and txt_in to cpu")
|
460 |
+
parser.add_argument(
|
461 |
+
"--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type"
|
462 |
+
)
|
463 |
+
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
|
464 |
+
parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
|
465 |
+
parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
|
466 |
+
parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arthimetic(RTX 4XXX+)")
|
467 |
+
parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
|
468 |
+
parser.add_argument(
|
469 |
+
"--compile_args",
|
470 |
+
nargs=4,
|
471 |
+
metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"),
|
472 |
+
default=["inductor", "max-autotune-no-cudagraphs", "False", "False"],
|
473 |
+
help="Torch.compile settings",
|
474 |
+
)
|
475 |
+
|
476 |
+
args = parser.parse_args()
|
477 |
+
|
478 |
+
assert (args.latent_path is None or len(args.latent_path) == 0) or (
|
479 |
+
args.output_type == "images" or args.output_type == "video"
|
480 |
+
), "latent_path is only supported for images or video output"
|
481 |
+
|
482 |
+
# update dit_weight based on model_base if not exists
|
483 |
+
|
484 |
+
if args.fp8_fast and not args.fp8:
|
485 |
+
raise ValueError("--fp8_fast requires --fp8")
|
486 |
+
|
487 |
+
return args
|
488 |
+
|
489 |
+
|
490 |
+
def check_inputs(args):
|
491 |
+
height = args.video_size[0]
|
492 |
+
width = args.video_size[1]
|
493 |
+
video_length = args.video_length
|
494 |
+
|
495 |
+
if height % 8 != 0 or width % 8 != 0:
|
496 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
497 |
+
return height, width, video_length
|
498 |
+
|
499 |
+
|
500 |
+
def main():
|
501 |
+
args = parse_args()
|
502 |
+
|
503 |
+
device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
|
504 |
+
device = torch.device(device)
|
505 |
+
dit_dtype = torch.bfloat16
|
506 |
+
dit_weight_dtype = torch.float8_e4m3fn if args.fp8 else dit_dtype
|
507 |
+
logger.info(f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}")
|
508 |
+
|
509 |
+
original_base_names = None
|
510 |
+
if args.latent_path is not None and len(args.latent_path) > 0:
|
511 |
+
original_base_names = []
|
512 |
+
latents_list = []
|
513 |
+
seeds = []
|
514 |
+
for latent_path in args.latent_path:
|
515 |
+
original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
|
516 |
+
seed = 0
|
517 |
+
|
518 |
+
if os.path.splitext(latent_path)[1] != ".safetensors":
|
519 |
+
latents = torch.load(latent_path, map_location="cpu")
|
520 |
+
else:
|
521 |
+
latents = load_file(latent_path)["latent"]
|
522 |
+
with safe_open(latent_path, framework="pt") as f:
|
523 |
+
metadata = f.metadata()
|
524 |
+
if metadata is None:
|
525 |
+
metadata = {}
|
526 |
+
logger.info(f"Loaded metadata: {metadata}")
|
527 |
+
|
528 |
+
if "seeds" in metadata:
|
529 |
+
seed = int(metadata["seeds"])
|
530 |
+
|
531 |
+
seeds.append(seed)
|
532 |
+
latents_list.append(latents)
|
533 |
+
|
534 |
+
logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")
|
535 |
+
latents = torch.stack(latents_list, dim=0)
|
536 |
+
else:
|
537 |
+
# prepare accelerator
|
538 |
+
mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16"
|
539 |
+
accelerator = accelerate.Accelerator(mixed_precision=mixed_precision)
|
540 |
+
|
541 |
+
# load prompt
|
542 |
+
prompt = args.prompt # TODO load prompts from file
|
543 |
+
assert prompt is not None, "prompt is required"
|
544 |
+
|
545 |
+
# check inputs: may be height, width, video_length etc will be changed for each generation in future
|
546 |
+
height, width, video_length = check_inputs(args)
|
547 |
+
|
548 |
+
# encode prompt with LLM and Text Encoder
|
549 |
+
logger.info(f"Encoding prompt: {prompt}")
|
550 |
+
|
551 |
+
do_classifier_free_guidance = args.guidance_scale != 1.0
|
552 |
+
if do_classifier_free_guidance:
|
553 |
+
negative_prompt = args.negative_prompt
|
554 |
+
if negative_prompt is None:
|
555 |
+
logger.info("Negative prompt is not provided, using empty prompt")
|
556 |
+
negative_prompt = ""
|
557 |
+
logger.info(f"Encoding negative prompt: {negative_prompt}")
|
558 |
+
prompt = [negative_prompt, prompt]
|
559 |
+
else:
|
560 |
+
if args.negative_prompt is not None:
|
561 |
+
logger.warning("Negative prompt is provided but guidance_scale is 1.0, negative prompt will be ignored.")
|
562 |
+
|
563 |
+
prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2 = encode_input_prompt(
|
564 |
+
prompt, args, device, args.fp8_llm, accelerator
|
565 |
+
)
|
566 |
+
|
567 |
+
# encode latents for video2video inference
|
568 |
+
video_latents = None
|
569 |
+
if args.video_path is not None:
|
570 |
+
# v2v inference
|
571 |
+
logger.info(f"Video2Video inference: {args.video_path}")
|
572 |
+
video = load_video(args.video_path, 0, video_length, bucket_reso=(width, height)) # list of frames
|
573 |
+
if len(video) < video_length:
|
574 |
+
raise ValueError(f"Video length is less than {video_length}")
|
575 |
+
video = np.stack(video, axis=0) # F, H, W, C
|
576 |
+
video = torch.from_numpy(video).permute(3, 0, 1, 2).unsqueeze(0).float() # 1, C, F, H, W
|
577 |
+
video = video / 255.0
|
578 |
+
|
579 |
+
logger.info(f"Encoding video to latents")
|
580 |
+
video_latents = encode_to_latents(args, video, device)
|
581 |
+
video_latents = video_latents.to(device=device, dtype=dit_dtype)
|
582 |
+
|
583 |
+
clean_memory_on_device(device)
|
584 |
+
|
585 |
+
# encode latents for image2video inference
|
586 |
+
image_latents = None
|
587 |
+
if args.image_path is not None:
|
588 |
+
# i2v inference
|
589 |
+
logger.info(f"Image2Video inference: {args.image_path}")
|
590 |
+
|
591 |
+
image = Image.open(args.image_path)
|
592 |
+
image = resize_image_to_bucket(image, (width, height)) # returns a numpy array
|
593 |
+
image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).unsqueeze(2).float() # 1, C, 1, H, W
|
594 |
+
image = image / 255.0
|
595 |
+
|
596 |
+
logger.info(f"Encoding image to latents")
|
597 |
+
image_latents = encode_to_latents(args, image, device) # 1, C, 1, H, W
|
598 |
+
image_latents = image_latents.to(device=device, dtype=dit_dtype)
|
599 |
+
|
600 |
+
clean_memory_on_device(device)
|
601 |
+
|
602 |
+
# load DiT model
|
603 |
+
blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
|
604 |
+
loading_device = "cpu" # if blocks_to_swap > 0 else device
|
605 |
+
|
606 |
+
logger.info(f"Loading DiT model from {args.dit}")
|
607 |
+
if args.attn_mode == "sdpa":
|
608 |
+
args.attn_mode = "torch"
|
609 |
+
|
610 |
+
# if image_latents is given, the model should be I2V model, so the in_channels should be 32
|
611 |
+
dit_in_channels = args.dit_in_channels if args.dit_in_channels is not None else (32 if image_latents is not None else 16)
|
612 |
+
|
613 |
+
# if we use LoRA, weigths should be bf16 instead of fp8, because merging should be done in bf16
|
614 |
+
# the model is too large, so we load the model to cpu. in addition, the .pt file is loaded to cpu anyway
|
615 |
+
# on the fly merging will be a solution for this issue for .safetenors files (not implemented yet)
|
616 |
+
transformer = load_transformer(
|
617 |
+
args.dit, args.attn_mode, args.split_attn, loading_device, dit_dtype, in_channels=dit_in_channels
|
618 |
+
)
|
619 |
+
transformer.eval()
|
620 |
+
|
621 |
+
# load LoRA weights
|
622 |
+
if args.lora_weight is not None and len(args.lora_weight) > 0:
|
623 |
+
for i, lora_weight in enumerate(args.lora_weight):
|
624 |
+
if args.lora_multiplier is not None and len(args.lora_multiplier) > i:
|
625 |
+
lora_multiplier = args.lora_multiplier[i]
|
626 |
+
else:
|
627 |
+
lora_multiplier = 1.0
|
628 |
+
|
629 |
+
logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}")
|
630 |
+
weights_sd = load_file(lora_weight)
|
631 |
+
|
632 |
+
# Filter to exclude keys that are part of single_blocks
|
633 |
+
if args.exclude_single_blocks:
|
634 |
+
filtered_weights = {k: v for k, v in weights_sd.items() if "single_blocks" not in k}
|
635 |
+
weights_sd = filtered_weights
|
636 |
+
|
637 |
+
if args.lycoris:
|
638 |
+
lycoris_net, _ = create_network_from_weights(
|
639 |
+
multiplier=lora_multiplier,
|
640 |
+
file=None,
|
641 |
+
weights_sd=weights_sd,
|
642 |
+
unet=transformer,
|
643 |
+
text_encoder=None,
|
644 |
+
vae=None,
|
645 |
+
for_inference=True,
|
646 |
+
)
|
647 |
+
else:
|
648 |
+
network = lora.create_arch_network_from_weights(
|
649 |
+
lora_multiplier, weights_sd, unet=transformer, for_inference=True
|
650 |
+
)
|
651 |
+
logger.info("Merging LoRA weights to DiT model")
|
652 |
+
|
653 |
+
# try:
|
654 |
+
# network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True)
|
655 |
+
# info = network.load_state_dict(weights_sd, strict=True)
|
656 |
+
# logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
|
657 |
+
# network.eval()
|
658 |
+
# network.to(device)
|
659 |
+
# except Exception as e:
|
660 |
+
if args.lycoris:
|
661 |
+
lycoris_net.merge_to(None, transformer, weights_sd, dtype=None, device=device)
|
662 |
+
else:
|
663 |
+
network.merge_to(None, transformer, weights_sd, device=device, non_blocking=True)
|
664 |
+
|
665 |
+
synchronize_device(device)
|
666 |
+
|
667 |
+
logger.info("LoRA weights loaded")
|
668 |
+
|
669 |
+
# save model here before casting to dit_weight_dtype
|
670 |
+
if args.save_merged_model:
|
671 |
+
logger.info(f"Saving merged model to {args.save_merged_model}")
|
672 |
+
mem_eff_save_file(transformer.state_dict(), args.save_merged_model) # save_file needs a lot of memory
|
673 |
+
logger.info("Merged model saved")
|
674 |
+
return
|
675 |
+
|
676 |
+
logger.info(f"Casting model to {dit_weight_dtype}")
|
677 |
+
transformer.to(dtype=dit_weight_dtype)
|
678 |
+
|
679 |
+
if args.fp8_fast:
|
680 |
+
logger.info("Enabling FP8 acceleration")
|
681 |
+
params_to_keep = {"norm", "bias", "time_in", "vector_in", "guidance_in", "txt_in", "img_in"}
|
682 |
+
for name, param in transformer.named_parameters():
|
683 |
+
dtype_to_use = dit_dtype if any(keyword in name for keyword in params_to_keep) else dit_weight_dtype
|
684 |
+
param.to(dtype=dtype_to_use)
|
685 |
+
convert_fp8_linear(transformer, dit_dtype, params_to_keep=params_to_keep)
|
686 |
+
|
687 |
+
if args.compile:
|
688 |
+
compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args
|
689 |
+
logger.info(
|
690 |
+
f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]"
|
691 |
+
)
|
692 |
+
torch._dynamo.config.cache_size_limit = 32
|
693 |
+
for i, block in enumerate(transformer.single_blocks):
|
694 |
+
compiled_block = torch.compile(
|
695 |
+
block,
|
696 |
+
backend=compile_backend,
|
697 |
+
mode=compile_mode,
|
698 |
+
dynamic=compile_dynamic.lower() in "true",
|
699 |
+
fullgraph=compile_fullgraph.lower() in "true",
|
700 |
+
)
|
701 |
+
transformer.single_blocks[i] = compiled_block
|
702 |
+
for i, block in enumerate(transformer.double_blocks):
|
703 |
+
compiled_block = torch.compile(
|
704 |
+
block,
|
705 |
+
backend=compile_backend,
|
706 |
+
mode=compile_mode,
|
707 |
+
dynamic=compile_dynamic.lower() in "true",
|
708 |
+
fullgraph=compile_fullgraph.lower() in "true",
|
709 |
+
)
|
710 |
+
transformer.double_blocks[i] = compiled_block
|
711 |
+
|
712 |
+
if blocks_to_swap > 0:
|
713 |
+
logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {device}")
|
714 |
+
transformer.enable_block_swap(blocks_to_swap, device, supports_backward=False)
|
715 |
+
transformer.move_to_device_except_swap_blocks(device)
|
716 |
+
transformer.prepare_block_swap_before_forward()
|
717 |
+
else:
|
718 |
+
logger.info(f"Moving model to {device}")
|
719 |
+
transformer.to(device=device)
|
720 |
+
if args.img_in_txt_in_offloading:
|
721 |
+
logger.info("Enable offloading img_in and txt_in to CPU")
|
722 |
+
transformer.enable_img_in_txt_in_offloading()
|
723 |
+
|
724 |
+
# load scheduler
|
725 |
+
logger.info(f"Loading scheduler")
|
726 |
+
scheduler = FlowMatchDiscreteScheduler(shift=args.flow_shift, reverse=True, solver="euler")
|
727 |
+
|
728 |
+
# Prepare timesteps
|
729 |
+
num_inference_steps = args.infer_steps
|
730 |
+
scheduler.set_timesteps(num_inference_steps, device=device) # n_tokens is not used in FlowMatchDiscreteScheduler
|
731 |
+
timesteps = scheduler.timesteps
|
732 |
+
|
733 |
+
# Prepare generator
|
734 |
+
num_videos_per_prompt = 1 # args.num_videos # currently only support 1 video per prompt, this is a batch size
|
735 |
+
seed = args.seed
|
736 |
+
if seed is None:
|
737 |
+
seeds = [random.randint(0, 2**32 - 1) for _ in range(num_videos_per_prompt)]
|
738 |
+
elif isinstance(seed, int):
|
739 |
+
seeds = [seed + i for i in range(num_videos_per_prompt)]
|
740 |
+
else:
|
741 |
+
raise ValueError(f"Seed must be an integer or None, got {seed}.")
|
742 |
+
generator = [torch.Generator(device).manual_seed(seed) for seed in seeds]
|
743 |
+
|
744 |
+
# Prepare noisy latents
|
745 |
+
num_channels_latents = 16 # transformer.config.in_channels
|
746 |
+
vae_scale_factor = 2 ** (4 - 1) # len(self.vae.config.block_out_channels) == 4
|
747 |
+
|
748 |
+
vae_ver = vae.VAE_VER
|
749 |
+
if "884" in vae_ver:
|
750 |
+
latent_video_length = (video_length - 1) // 4 + 1
|
751 |
+
elif "888" in vae_ver:
|
752 |
+
latent_video_length = (video_length - 1) // 8 + 1
|
753 |
+
else:
|
754 |
+
latent_video_length = video_length
|
755 |
+
|
756 |
+
# shape = (
|
757 |
+
# num_videos_per_prompt,
|
758 |
+
# num_channels_latents,
|
759 |
+
# latent_video_length,
|
760 |
+
# height // vae_scale_factor,
|
761 |
+
# width // vae_scale_factor,
|
762 |
+
# )
|
763 |
+
# latents = randn_tensor(shape, generator=generator, device=device, dtype=dit_dtype)
|
764 |
+
|
765 |
+
# make first N frames to be the same if the given seed is same
|
766 |
+
shape_of_frame = (num_videos_per_prompt, num_channels_latents, 1, height // vae_scale_factor, width // vae_scale_factor)
|
767 |
+
latents = []
|
768 |
+
for i in range(latent_video_length):
|
769 |
+
latents.append(randn_tensor(shape_of_frame, generator=generator, device=device, dtype=dit_dtype))
|
770 |
+
latents = torch.cat(latents, dim=2)
|
771 |
+
|
772 |
+
# pad image_latents to match the length of video_latents
|
773 |
+
if image_latents is not None:
|
774 |
+
zero_latents = torch.zeros_like(latents)
|
775 |
+
zero_latents[:, :, :1, :, :] = image_latents
|
776 |
+
image_latents = zero_latents
|
777 |
+
|
778 |
+
if args.video_path is not None:
|
779 |
+
# v2v inference
|
780 |
+
noise = latents
|
781 |
+
assert noise.shape == video_latents.shape, f"noise shape {noise.shape} != video_latents shape {video_latents.shape}"
|
782 |
+
|
783 |
+
num_inference_steps = int(num_inference_steps * args.strength)
|
784 |
+
timestep_start = scheduler.timesteps[-num_inference_steps] # larger strength, less inference steps and more start time
|
785 |
+
t = timestep_start / 1000.0
|
786 |
+
latents = noise * t + video_latents * (1 - t)
|
787 |
+
|
788 |
+
timesteps = timesteps[-num_inference_steps:]
|
789 |
+
|
790 |
+
logger.info(f"strength: {args.strength}, num_inference_steps: {num_inference_steps}, timestep_start: {timestep_start}")
|
791 |
+
|
792 |
+
# FlowMatchDiscreteScheduler does not have init_noise_sigma
|
793 |
+
|
794 |
+
# Denoising loop
|
795 |
+
embedded_guidance_scale = args.embedded_cfg_scale
|
796 |
+
if embedded_guidance_scale is not None:
|
797 |
+
guidance_expand = torch.tensor([embedded_guidance_scale * 1000.0] * latents.shape[0], dtype=torch.float32, device="cpu")
|
798 |
+
guidance_expand = guidance_expand.to(device=device, dtype=dit_dtype)
|
799 |
+
if do_classifier_free_guidance:
|
800 |
+
guidance_expand = torch.cat([guidance_expand, guidance_expand], dim=0)
|
801 |
+
else:
|
802 |
+
guidance_expand = None
|
803 |
+
freqs_cos, freqs_sin = get_rotary_pos_embed(vae_ver, transformer, video_length, height, width)
|
804 |
+
# n_tokens = freqs_cos.shape[0]
|
805 |
+
|
806 |
+
# move and cast all inputs to the correct device and dtype
|
807 |
+
prompt_embeds = prompt_embeds.to(device=device, dtype=dit_dtype)
|
808 |
+
prompt_mask = prompt_mask.to(device=device)
|
809 |
+
prompt_embeds_2 = prompt_embeds_2.to(device=device, dtype=dit_dtype)
|
810 |
+
prompt_mask_2 = prompt_mask_2.to(device=device)
|
811 |
+
|
812 |
+
freqs_cos = freqs_cos.to(device=device, dtype=dit_dtype)
|
813 |
+
freqs_sin = freqs_sin.to(device=device, dtype=dit_dtype)
|
814 |
+
|
815 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order # this should be 0 in v2v inference
|
816 |
+
|
817 |
+
# assert split_uncond and split_attn
|
818 |
+
if args.split_attn and do_classifier_free_guidance and not args.split_uncond:
|
819 |
+
logger.warning("split_attn is enabled, split_uncond will be enabled as well.")
|
820 |
+
args.split_uncond = True
|
821 |
+
|
822 |
+
# with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]) as p:
|
823 |
+
with tqdm(total=num_inference_steps) as progress_bar:
|
824 |
+
for i, t in enumerate(timesteps):
|
825 |
+
latents = scheduler.scale_model_input(latents, t)
|
826 |
+
|
827 |
+
# predict the noise residual
|
828 |
+
with torch.no_grad(), accelerator.autocast():
|
829 |
+
latents_input = latents if not do_classifier_free_guidance else torch.cat([latents, latents], dim=0)
|
830 |
+
if image_latents is not None:
|
831 |
+
latents_image_input = (
|
832 |
+
image_latents if not do_classifier_free_guidance else torch.cat([image_latents, image_latents], dim=0)
|
833 |
+
)
|
834 |
+
latents_input = torch.cat([latents_input, latents_image_input], dim=1) # 1 or 2, C*2, F, H, W
|
835 |
+
|
836 |
+
batch_size = 1 if args.split_uncond else latents_input.shape[0]
|
837 |
+
|
838 |
+
noise_pred_list = []
|
839 |
+
for j in range(0, latents_input.shape[0], batch_size):
|
840 |
+
noise_pred = transformer( # For an input image (129, 192, 336) (1, 256, 256)
|
841 |
+
latents_input[j : j + batch_size], # [1, 16, 33, 24, 42]
|
842 |
+
t.repeat(batch_size).to(device=device, dtype=dit_dtype), # [1]
|
843 |
+
text_states=prompt_embeds[j : j + batch_size], # [1, 256, 4096]
|
844 |
+
text_mask=prompt_mask[j : j + batch_size], # [1, 256]
|
845 |
+
text_states_2=prompt_embeds_2[j : j + batch_size], # [1, 768]
|
846 |
+
freqs_cos=freqs_cos, # [seqlen, head_dim]
|
847 |
+
freqs_sin=freqs_sin, # [seqlen, head_dim]
|
848 |
+
guidance=guidance_expand[j : j + batch_size], # [1]
|
849 |
+
return_dict=True,
|
850 |
+
)["x"]
|
851 |
+
noise_pred_list.append(noise_pred)
|
852 |
+
noise_pred = torch.cat(noise_pred_list, dim=0)
|
853 |
+
|
854 |
+
# perform classifier free guidance
|
855 |
+
if do_classifier_free_guidance:
|
856 |
+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
857 |
+
noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
858 |
+
|
859 |
+
# # SkyReels' rescale noise config is omitted for now
|
860 |
+
# if guidance_rescale > 0.0:
|
861 |
+
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
862 |
+
# noise_pred = rescale_noise_cfg(
|
863 |
+
# noise_pred,
|
864 |
+
# noise_pred_cond,
|
865 |
+
# guidance_rescale=self.guidance_rescale,
|
866 |
+
# )
|
867 |
+
|
868 |
+
# compute the previous noisy sample x_t -> x_t-1
|
869 |
+
latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
870 |
+
|
871 |
+
# update progress bar
|
872 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
873 |
+
if progress_bar is not None:
|
874 |
+
progress_bar.update()
|
875 |
+
|
876 |
+
# print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))
|
877 |
+
# print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
|
878 |
+
|
879 |
+
latents = latents.detach().cpu()
|
880 |
+
transformer = None
|
881 |
+
clean_memory_on_device(device)
|
882 |
+
|
883 |
+
# Save samples
|
884 |
+
output_type = args.output_type
|
885 |
+
save_path = args.save_path # if args.save_path_suffix == "" else f"{args.save_path}_{args.save_path_suffix}"
|
886 |
+
os.makedirs(save_path, exist_ok=True)
|
887 |
+
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
|
888 |
+
|
889 |
+
if output_type == "latent" or output_type == "both":
|
890 |
+
# save latent
|
891 |
+
for i, latent in enumerate(latents):
|
892 |
+
latent_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}_latent.safetensors"
|
893 |
+
|
894 |
+
if args.no_metadata:
|
895 |
+
metadata = None
|
896 |
+
else:
|
897 |
+
metadata = {
|
898 |
+
"seeds": f"{seeds[i]}",
|
899 |
+
"prompt": f"{args.prompt}",
|
900 |
+
"height": f"{height}",
|
901 |
+
"width": f"{width}",
|
902 |
+
"video_length": f"{video_length}",
|
903 |
+
"infer_steps": f"{num_inference_steps}",
|
904 |
+
"guidance_scale": f"{args.guidance_scale}",
|
905 |
+
"embedded_cfg_scale": f"{args.embedded_cfg_scale}",
|
906 |
+
}
|
907 |
+
if args.negative_prompt is not None:
|
908 |
+
metadata["negative_prompt"] = f"{args.negative_prompt}"
|
909 |
+
sd = {"latent": latent}
|
910 |
+
save_file(sd, latent_path, metadata=metadata)
|
911 |
+
|
912 |
+
logger.info(f"Latent save to: {latent_path}")
|
913 |
+
if output_type == "video" or output_type == "both":
|
914 |
+
# save video
|
915 |
+
videos = decode_latents(args, latents, device)
|
916 |
+
for i, sample in enumerate(videos):
|
917 |
+
original_name = "" if original_base_names is None else f"_{original_base_names[i]}"
|
918 |
+
sample = sample.unsqueeze(0)
|
919 |
+
video_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}{original_name}.mp4"
|
920 |
+
save_videos_grid(sample, video_path, fps=args.fps)
|
921 |
+
logger.info(f"Sample save to: {video_path}")
|
922 |
+
elif output_type == "images":
|
923 |
+
# save images
|
924 |
+
videos = decode_latents(args, latents, device)
|
925 |
+
for i, sample in enumerate(videos):
|
926 |
+
original_name = "" if original_base_names is None else f"_{original_base_names[i]}"
|
927 |
+
sample = sample.unsqueeze(0)
|
928 |
+
image_name = f"{time_flag}_{i}_{seeds[i]}{original_name}"
|
929 |
+
save_images_grid(sample, save_path, image_name)
|
930 |
+
logger.info(f"Sample images save to: {save_path}/{image_name}")
|
931 |
+
|
932 |
+
logger.info("Done!")
|
933 |
+
|
934 |
+
|
935 |
+
if __name__ == "__main__":
|
936 |
+
main()
|
base_wan_generate_video.py
ADDED
@@ -0,0 +1,1892 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from datetime import datetime
|
3 |
+
import gc
|
4 |
+
import random
|
5 |
+
import os
|
6 |
+
import re
|
7 |
+
import time
|
8 |
+
import math
|
9 |
+
import copy
|
10 |
+
from types import ModuleType, SimpleNamespace
|
11 |
+
from typing import Tuple, Optional, List, Union, Any, Dict
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import accelerate
|
15 |
+
from accelerate import Accelerator
|
16 |
+
from safetensors.torch import load_file, save_file
|
17 |
+
from safetensors import safe_open
|
18 |
+
from PIL import Image
|
19 |
+
import cv2
|
20 |
+
import numpy as np
|
21 |
+
import torchvision.transforms.functional as TF
|
22 |
+
from tqdm import tqdm
|
23 |
+
|
24 |
+
from networks import lora_wan
|
25 |
+
from utils.safetensors_utils import mem_eff_save_file, load_safetensors
|
26 |
+
from wan.configs import WAN_CONFIGS, SUPPORTED_SIZES
|
27 |
+
import wan
|
28 |
+
from wan.modules.model import WanModel, load_wan_model, detect_wan_sd_dtype
|
29 |
+
from wan.modules.vae import WanVAE
|
30 |
+
from wan.modules.t5 import T5EncoderModel
|
31 |
+
from wan.modules.clip import CLIPModel
|
32 |
+
from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
|
33 |
+
from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
|
34 |
+
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
35 |
+
|
36 |
+
try:
|
37 |
+
from lycoris.kohya import create_network_from_weights
|
38 |
+
except:
|
39 |
+
pass
|
40 |
+
|
41 |
+
from utils.model_utils import str_to_dtype
|
42 |
+
from utils.device_utils import clean_memory_on_device
|
43 |
+
from hv_generate_video import save_images_grid, save_videos_grid, synchronize_device
|
44 |
+
from dataset.image_video_dataset import load_video
|
45 |
+
|
46 |
+
import logging
|
47 |
+
|
48 |
+
logger = logging.getLogger(__name__)
|
49 |
+
logging.basicConfig(level=logging.INFO)
|
50 |
+
|
51 |
+
|
52 |
+
class GenerationSettings:
|
53 |
+
def __init__(
|
54 |
+
self, device: torch.device, cfg, dit_dtype: torch.dtype, dit_weight_dtype: Optional[torch.dtype], vae_dtype: torch.dtype
|
55 |
+
):
|
56 |
+
self.device = device
|
57 |
+
self.cfg = cfg
|
58 |
+
self.dit_dtype = dit_dtype
|
59 |
+
self.dit_weight_dtype = dit_weight_dtype
|
60 |
+
self.vae_dtype = vae_dtype
|
61 |
+
|
62 |
+
|
63 |
+
def parse_args() -> argparse.Namespace:
|
64 |
+
"""parse command line arguments"""
|
65 |
+
parser = argparse.ArgumentParser(description="Wan 2.1 inference script")
|
66 |
+
|
67 |
+
# WAN arguments
|
68 |
+
parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory (Wan 2.1 official).")
|
69 |
+
parser.add_argument("--task", type=str, default="t2v-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run.")
|
70 |
+
parser.add_argument(
|
71 |
+
"--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++", "vanilla"], help="The solver used to sample."
|
72 |
+
)
|
73 |
+
|
74 |
+
parser.add_argument("--dit", type=str, default=None, help="DiT checkpoint path")
|
75 |
+
parser.add_argument("--vae", type=str, default=None, help="VAE checkpoint path")
|
76 |
+
parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is bfloat16")
|
77 |
+
parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU")
|
78 |
+
parser.add_argument("--t5", type=str, default=None, help="text encoder (T5) checkpoint path")
|
79 |
+
parser.add_argument("--clip", type=str, default=None, help="text encoder (CLIP) checkpoint path")
|
80 |
+
# LoRA
|
81 |
+
parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
|
82 |
+
parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
|
83 |
+
parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns")
|
84 |
+
parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns")
|
85 |
+
parser.add_argument(
|
86 |
+
"--save_merged_model",
|
87 |
+
type=str,
|
88 |
+
default=None,
|
89 |
+
help="Save merged model to path. If specified, no inference will be performed.",
|
90 |
+
)
|
91 |
+
|
92 |
+
# inference
|
93 |
+
parser.add_argument("--prompt", type=str, default=None, help="prompt for generation")
|
94 |
+
parser.add_argument(
|
95 |
+
"--negative_prompt",
|
96 |
+
type=str,
|
97 |
+
default=None,
|
98 |
+
help="negative prompt for generation, use default negative prompt if not specified",
|
99 |
+
)
|
100 |
+
parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size, height and width")
|
101 |
+
parser.add_argument("--video_length", type=int, default=None, help="video length, Default depends on task")
|
102 |
+
parser.add_argument("--fps", type=int, default=16, help="video fps, Default is 16")
|
103 |
+
parser.add_argument("--infer_steps", type=int, default=None, help="number of inference steps")
|
104 |
+
parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
|
105 |
+
parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
|
106 |
+
parser.add_argument(
|
107 |
+
"--cpu_noise", action="store_true", help="Use CPU to generate noise (compatible with ComfyUI). Default is False."
|
108 |
+
)
|
109 |
+
parser.add_argument(
|
110 |
+
"--guidance_scale",
|
111 |
+
type=float,
|
112 |
+
default=5.0,
|
113 |
+
help="Guidance scale for classifier free guidance. Default is 5.0.",
|
114 |
+
)
|
115 |
+
parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference")
|
116 |
+
parser.add_argument("--image_path", type=str, default=None, help="path to image for image2video inference")
|
117 |
+
parser.add_argument("--end_image_path", type=str, default=None, help="path to end image for image2video inference")
|
118 |
+
parser.add_argument(
|
119 |
+
"--control_path",
|
120 |
+
type=str,
|
121 |
+
default=None,
|
122 |
+
help="path to control video for inference with controlnet. video file or directory with images",
|
123 |
+
)
|
124 |
+
parser.add_argument("--trim_tail_frames", type=int, default=0, help="trim tail N frames from the video before saving")
|
125 |
+
parser.add_argument(
|
126 |
+
"--cfg_skip_mode",
|
127 |
+
type=str,
|
128 |
+
default="none",
|
129 |
+
choices=["early", "late", "middle", "early_late", "alternate", "none"],
|
130 |
+
help="CFG skip mode. each mode skips different parts of the CFG. "
|
131 |
+
" early: initial steps, late: later steps, middle: middle steps, early_late: both early and late, alternate: alternate, none: no skip (default)",
|
132 |
+
)
|
133 |
+
parser.add_argument(
|
134 |
+
"--cfg_apply_ratio",
|
135 |
+
type=float,
|
136 |
+
default=None,
|
137 |
+
help="The ratio of steps to apply CFG (0.0 to 1.0). Default is None (apply all steps).",
|
138 |
+
)
|
139 |
+
parser.add_argument(
|
140 |
+
"--slg_layers", type=str, default=None, help="Skip block (layer) indices for SLG (Skip Layer Guidance), comma separated"
|
141 |
+
)
|
142 |
+
parser.add_argument(
|
143 |
+
"--slg_scale",
|
144 |
+
type=float,
|
145 |
+
default=3.0,
|
146 |
+
help="scale for SLG classifier free guidance. Default is 3.0. Ignored if slg_mode is None or uncond",
|
147 |
+
)
|
148 |
+
parser.add_argument("--slg_start", type=float, default=0.0, help="start ratio for inference steps for SLG. Default is 0.0.")
|
149 |
+
parser.add_argument("--slg_end", type=float, default=0.3, help="end ratio for inference steps for SLG. Default is 0.3.")
|
150 |
+
parser.add_argument(
|
151 |
+
"--slg_mode",
|
152 |
+
type=str,
|
153 |
+
default=None,
|
154 |
+
choices=["original", "uncond"],
|
155 |
+
help="SLG mode. original: same as SD3, uncond: replace uncond pred with SLG pred",
|
156 |
+
)
|
157 |
+
|
158 |
+
# Flow Matching
|
159 |
+
parser.add_argument(
|
160 |
+
"--flow_shift",
|
161 |
+
type=float,
|
162 |
+
default=None,
|
163 |
+
help="Shift factor for flow matching schedulers. Default depends on task.",
|
164 |
+
)
|
165 |
+
|
166 |
+
parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
|
167 |
+
parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
|
168 |
+
parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled")
|
169 |
+
parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model")
|
170 |
+
parser.add_argument(
|
171 |
+
"--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
|
172 |
+
)
|
173 |
+
parser.add_argument(
|
174 |
+
"--attn_mode",
|
175 |
+
type=str,
|
176 |
+
default="torch",
|
177 |
+
choices=["flash", "flash2", "flash3", "torch", "sageattn", "xformers", "sdpa"],
|
178 |
+
help="attention mode",
|
179 |
+
)
|
180 |
+
parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model")
|
181 |
+
parser.add_argument(
|
182 |
+
"--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type"
|
183 |
+
)
|
184 |
+
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
|
185 |
+
parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
|
186 |
+
parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
|
187 |
+
parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
|
188 |
+
parser.add_argument(
|
189 |
+
"--compile_args",
|
190 |
+
nargs=4,
|
191 |
+
metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"),
|
192 |
+
default=["inductor", "max-autotune-no-cudagraphs", "False", "False"],
|
193 |
+
help="Torch.compile settings",
|
194 |
+
)
|
195 |
+
|
196 |
+
# New arguments for batch and interactive modes
|
197 |
+
parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file")
|
198 |
+
parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console")
|
199 |
+
|
200 |
+
args = parser.parse_args()
|
201 |
+
|
202 |
+
# Validate arguments
|
203 |
+
if args.from_file and args.interactive:
|
204 |
+
raise ValueError("Cannot use both --from_file and --interactive at the same time")
|
205 |
+
|
206 |
+
if args.prompt is None and not args.from_file and not args.interactive and args.latent_path is None:
|
207 |
+
raise ValueError("Either --prompt, --from_file, --interactive, or --latent_path must be specified")
|
208 |
+
|
209 |
+
assert (args.latent_path is None or len(args.latent_path) == 0) or (
|
210 |
+
args.output_type == "images" or args.output_type == "video"
|
211 |
+
), "latent_path is only supported for images or video output"
|
212 |
+
|
213 |
+
return args
|
214 |
+
|
215 |
+
|
216 |
+
def parse_prompt_line(line: str) -> Dict[str, Any]:
|
217 |
+
"""Parse a prompt line into a dictionary of argument overrides
|
218 |
+
|
219 |
+
Args:
|
220 |
+
line: Prompt line with options
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
Dict[str, Any]: Dictionary of argument overrides
|
224 |
+
"""
|
225 |
+
# TODO common function with hv_train_network.line_to_prompt_dict
|
226 |
+
parts = line.split(" --")
|
227 |
+
prompt = parts[0].strip()
|
228 |
+
|
229 |
+
# Create dictionary of overrides
|
230 |
+
overrides = {"prompt": prompt}
|
231 |
+
|
232 |
+
for part in parts[1:]:
|
233 |
+
if not part.strip():
|
234 |
+
continue
|
235 |
+
option_parts = part.split(" ", 1)
|
236 |
+
option = option_parts[0].strip()
|
237 |
+
value = option_parts[1].strip() if len(option_parts) > 1 else ""
|
238 |
+
|
239 |
+
# Map options to argument names
|
240 |
+
if option == "w":
|
241 |
+
overrides["video_size_width"] = int(value)
|
242 |
+
elif option == "h":
|
243 |
+
overrides["video_size_height"] = int(value)
|
244 |
+
elif option == "f":
|
245 |
+
overrides["video_length"] = int(value)
|
246 |
+
elif option == "d":
|
247 |
+
overrides["seed"] = int(value)
|
248 |
+
elif option == "s":
|
249 |
+
overrides["infer_steps"] = int(value)
|
250 |
+
elif option == "g" or option == "l":
|
251 |
+
overrides["guidance_scale"] = float(value)
|
252 |
+
elif option == "fs":
|
253 |
+
overrides["flow_shift"] = float(value)
|
254 |
+
elif option == "i":
|
255 |
+
overrides["image_path"] = value
|
256 |
+
elif option == "cn":
|
257 |
+
overrides["control_path"] = value
|
258 |
+
elif option == "n":
|
259 |
+
overrides["negative_prompt"] = value
|
260 |
+
|
261 |
+
return overrides
|
262 |
+
|
263 |
+
|
264 |
+
def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace:
|
265 |
+
"""Apply overrides to args
|
266 |
+
|
267 |
+
Args:
|
268 |
+
args: Original arguments
|
269 |
+
overrides: Dictionary of overrides
|
270 |
+
|
271 |
+
Returns:
|
272 |
+
argparse.Namespace: New arguments with overrides applied
|
273 |
+
"""
|
274 |
+
args_copy = copy.deepcopy(args)
|
275 |
+
|
276 |
+
for key, value in overrides.items():
|
277 |
+
if key == "video_size_width":
|
278 |
+
args_copy.video_size[1] = value
|
279 |
+
elif key == "video_size_height":
|
280 |
+
args_copy.video_size[0] = value
|
281 |
+
else:
|
282 |
+
setattr(args_copy, key, value)
|
283 |
+
|
284 |
+
return args_copy
|
285 |
+
|
286 |
+
|
287 |
+
def get_task_defaults(task: str, size: Optional[Tuple[int, int]] = None) -> Tuple[int, float, int, bool]:
|
288 |
+
"""Return default values for each task
|
289 |
+
|
290 |
+
Args:
|
291 |
+
task: task name (t2v, t2i, i2v etc.)
|
292 |
+
size: size of the video (width, height)
|
293 |
+
|
294 |
+
Returns:
|
295 |
+
Tuple[int, float, int, bool]: (infer_steps, flow_shift, video_length, needs_clip)
|
296 |
+
"""
|
297 |
+
width, height = size if size else (0, 0)
|
298 |
+
|
299 |
+
if "t2i" in task:
|
300 |
+
return 50, 5.0, 1, False
|
301 |
+
elif "i2v" in task:
|
302 |
+
flow_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0
|
303 |
+
return 40, flow_shift, 81, True
|
304 |
+
else: # t2v or default
|
305 |
+
return 50, 5.0, 81, False
|
306 |
+
|
307 |
+
|
308 |
+
def setup_args(args: argparse.Namespace) -> argparse.Namespace:
|
309 |
+
"""Validate and set default values for optional arguments
|
310 |
+
|
311 |
+
Args:
|
312 |
+
args: command line arguments
|
313 |
+
|
314 |
+
Returns:
|
315 |
+
argparse.Namespace: updated arguments
|
316 |
+
"""
|
317 |
+
# Get default values for the task
|
318 |
+
infer_steps, flow_shift, video_length, _ = get_task_defaults(args.task, tuple(args.video_size))
|
319 |
+
|
320 |
+
# Apply default values to unset arguments
|
321 |
+
if args.infer_steps is None:
|
322 |
+
args.infer_steps = infer_steps
|
323 |
+
if args.flow_shift is None:
|
324 |
+
args.flow_shift = flow_shift
|
325 |
+
if args.video_length is None:
|
326 |
+
args.video_length = video_length
|
327 |
+
|
328 |
+
# Force video_length to 1 for t2i tasks
|
329 |
+
if "t2i" in args.task:
|
330 |
+
assert args.video_length == 1, f"video_length should be 1 for task {args.task}"
|
331 |
+
|
332 |
+
# parse slg_layers
|
333 |
+
if args.slg_layers is not None:
|
334 |
+
args.slg_layers = list(map(int, args.slg_layers.split(",")))
|
335 |
+
|
336 |
+
return args
|
337 |
+
|
338 |
+
|
339 |
+
def check_inputs(args: argparse.Namespace) -> Tuple[int, int, int]:
|
340 |
+
"""Validate video size and length
|
341 |
+
|
342 |
+
Args:
|
343 |
+
args: command line arguments
|
344 |
+
|
345 |
+
Returns:
|
346 |
+
Tuple[int, int, int]: (height, width, video_length)
|
347 |
+
"""
|
348 |
+
height = args.video_size[0]
|
349 |
+
width = args.video_size[1]
|
350 |
+
size = f"{width}*{height}"
|
351 |
+
|
352 |
+
if size not in SUPPORTED_SIZES[args.task]:
|
353 |
+
logger.warning(f"Size {size} is not supported for task {args.task}. Supported sizes are {SUPPORTED_SIZES[args.task]}.")
|
354 |
+
|
355 |
+
video_length = args.video_length
|
356 |
+
|
357 |
+
if height % 8 != 0 or width % 8 != 0:
|
358 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
359 |
+
|
360 |
+
return height, width, video_length
|
361 |
+
|
362 |
+
|
363 |
+
def calculate_dimensions(video_size: Tuple[int, int], video_length: int, config) -> Tuple[Tuple[int, int, int, int], int]:
|
364 |
+
"""calculate dimensions for the generation
|
365 |
+
|
366 |
+
Args:
|
367 |
+
video_size: video frame size (height, width)
|
368 |
+
video_length: number of frames in the video
|
369 |
+
config: model configuration
|
370 |
+
|
371 |
+
Returns:
|
372 |
+
Tuple[Tuple[int, int, int, int], int]:
|
373 |
+
((channels, frames, height, width), seq_len)
|
374 |
+
"""
|
375 |
+
height, width = video_size
|
376 |
+
frames = video_length
|
377 |
+
|
378 |
+
# calculate latent space dimensions
|
379 |
+
lat_f = (frames - 1) // config.vae_stride[0] + 1
|
380 |
+
lat_h = height // config.vae_stride[1]
|
381 |
+
lat_w = width // config.vae_stride[2]
|
382 |
+
|
383 |
+
# calculate sequence length
|
384 |
+
seq_len = math.ceil((lat_h * lat_w) / (config.patch_size[1] * config.patch_size[2]) * lat_f)
|
385 |
+
|
386 |
+
return ((16, lat_f, lat_h, lat_w), seq_len)
|
387 |
+
|
388 |
+
|
389 |
+
def load_vae(args: argparse.Namespace, config, device: torch.device, dtype: torch.dtype) -> WanVAE:
|
390 |
+
"""load VAE model
|
391 |
+
|
392 |
+
Args:
|
393 |
+
args: command line arguments
|
394 |
+
config: model configuration
|
395 |
+
device: device to use
|
396 |
+
dtype: data type for the model
|
397 |
+
|
398 |
+
Returns:
|
399 |
+
WanVAE: loaded VAE model
|
400 |
+
"""
|
401 |
+
vae_path = args.vae if args.vae is not None else os.path.join(args.ckpt_dir, config.vae_checkpoint)
|
402 |
+
|
403 |
+
logger.info(f"Loading VAE model from {vae_path}")
|
404 |
+
cache_device = torch.device("cpu") if args.vae_cache_cpu else None
|
405 |
+
vae = WanVAE(vae_path=vae_path, device=device, dtype=dtype, cache_device=cache_device)
|
406 |
+
return vae
|
407 |
+
|
408 |
+
|
409 |
+
def load_text_encoder(args: argparse.Namespace, config, device: torch.device) -> T5EncoderModel:
|
410 |
+
"""load text encoder (T5) model
|
411 |
+
|
412 |
+
Args:
|
413 |
+
args: command line arguments
|
414 |
+
config: model configuration
|
415 |
+
device: device to use
|
416 |
+
|
417 |
+
Returns:
|
418 |
+
T5EncoderModel: loaded text encoder model
|
419 |
+
"""
|
420 |
+
checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_checkpoint)
|
421 |
+
tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_tokenizer)
|
422 |
+
|
423 |
+
text_encoder = T5EncoderModel(
|
424 |
+
text_len=config.text_len,
|
425 |
+
dtype=config.t5_dtype,
|
426 |
+
device=device,
|
427 |
+
checkpoint_path=checkpoint_path,
|
428 |
+
tokenizer_path=tokenizer_path,
|
429 |
+
weight_path=args.t5,
|
430 |
+
fp8=args.fp8_t5,
|
431 |
+
)
|
432 |
+
|
433 |
+
return text_encoder
|
434 |
+
|
435 |
+
|
436 |
+
def load_clip_model(args: argparse.Namespace, config, device: torch.device) -> CLIPModel:
|
437 |
+
"""load CLIP model (for I2V only)
|
438 |
+
|
439 |
+
Args:
|
440 |
+
args: command line arguments
|
441 |
+
config: model configuration
|
442 |
+
device: device to use
|
443 |
+
|
444 |
+
Returns:
|
445 |
+
CLIPModel: loaded CLIP model
|
446 |
+
"""
|
447 |
+
checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_checkpoint)
|
448 |
+
tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_tokenizer)
|
449 |
+
|
450 |
+
clip = CLIPModel(
|
451 |
+
dtype=config.clip_dtype,
|
452 |
+
device=device,
|
453 |
+
checkpoint_path=checkpoint_path,
|
454 |
+
tokenizer_path=tokenizer_path,
|
455 |
+
weight_path=args.clip,
|
456 |
+
)
|
457 |
+
|
458 |
+
return clip
|
459 |
+
|
460 |
+
|
461 |
+
def load_dit_model(
|
462 |
+
args: argparse.Namespace,
|
463 |
+
config,
|
464 |
+
device: torch.device,
|
465 |
+
dit_dtype: torch.dtype,
|
466 |
+
dit_weight_dtype: Optional[torch.dtype] = None,
|
467 |
+
is_i2v: bool = False,
|
468 |
+
) -> WanModel:
|
469 |
+
"""load DiT model
|
470 |
+
|
471 |
+
Args:
|
472 |
+
args: command line arguments
|
473 |
+
config: model configuration
|
474 |
+
device: device to use
|
475 |
+
dit_dtype: data type for the model
|
476 |
+
dit_weight_dtype: data type for the model weights. None for as-is
|
477 |
+
is_i2v: I2V mode
|
478 |
+
|
479 |
+
Returns:
|
480 |
+
WanModel: loaded DiT model
|
481 |
+
"""
|
482 |
+
loading_device = "cpu"
|
483 |
+
if args.blocks_to_swap == 0 and args.lora_weight is None and not args.fp8_scaled:
|
484 |
+
loading_device = device
|
485 |
+
|
486 |
+
loading_weight_dtype = dit_weight_dtype
|
487 |
+
if args.fp8_scaled or args.lora_weight is not None:
|
488 |
+
loading_weight_dtype = dit_dtype # load as-is
|
489 |
+
|
490 |
+
# do not fp8 optimize because we will merge LoRA weights
|
491 |
+
model = load_wan_model(config, device, args.dit, args.attn_mode, False, loading_device, loading_weight_dtype, False)
|
492 |
+
|
493 |
+
return model
|
494 |
+
|
495 |
+
|
496 |
+
def merge_lora_weights(lora_module: ModuleType, model: torch.nn.Module, args: argparse.Namespace, device: torch.device) -> None:
|
497 |
+
"""merge LoRA weights to the model
|
498 |
+
|
499 |
+
Args:
|
500 |
+
model: DiT model
|
501 |
+
args: command line arguments
|
502 |
+
device: device to use
|
503 |
+
"""
|
504 |
+
if args.lora_weight is None or len(args.lora_weight) == 0:
|
505 |
+
return
|
506 |
+
|
507 |
+
for i, lora_weight in enumerate(args.lora_weight):
|
508 |
+
if args.lora_multiplier is not None and len(args.lora_multiplier) > i:
|
509 |
+
lora_multiplier = args.lora_multiplier[i]
|
510 |
+
else:
|
511 |
+
lora_multiplier = 1.0
|
512 |
+
|
513 |
+
logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}")
|
514 |
+
weights_sd = load_file(lora_weight)
|
515 |
+
|
516 |
+
# apply include/exclude patterns
|
517 |
+
original_key_count = len(weights_sd.keys())
|
518 |
+
if args.include_patterns is not None and len(args.include_patterns) > i:
|
519 |
+
include_pattern = args.include_patterns[i]
|
520 |
+
regex_include = re.compile(include_pattern)
|
521 |
+
weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)}
|
522 |
+
logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}")
|
523 |
+
if args.exclude_patterns is not None and len(args.exclude_patterns) > i:
|
524 |
+
original_key_count_ex = len(weights_sd.keys())
|
525 |
+
exclude_pattern = args.exclude_patterns[i]
|
526 |
+
regex_exclude = re.compile(exclude_pattern)
|
527 |
+
weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)}
|
528 |
+
logger.info(
|
529 |
+
f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}"
|
530 |
+
)
|
531 |
+
if len(weights_sd) != original_key_count:
|
532 |
+
remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()]))
|
533 |
+
remaining_keys.sort()
|
534 |
+
logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}")
|
535 |
+
if len(weights_sd) == 0:
|
536 |
+
logger.warning(f"No keys left after filtering.")
|
537 |
+
|
538 |
+
if args.lycoris:
|
539 |
+
lycoris_net, _ = create_network_from_weights(
|
540 |
+
multiplier=lora_multiplier,
|
541 |
+
file=None,
|
542 |
+
weights_sd=weights_sd,
|
543 |
+
unet=model,
|
544 |
+
text_encoder=None,
|
545 |
+
vae=None,
|
546 |
+
for_inference=True,
|
547 |
+
)
|
548 |
+
lycoris_net.merge_to(None, model, weights_sd, dtype=None, device=device)
|
549 |
+
else:
|
550 |
+
network = lora_module.create_arch_network_from_weights(lora_multiplier, weights_sd, unet=model, for_inference=True)
|
551 |
+
network.merge_to(None, model, weights_sd, device=device, non_blocking=True)
|
552 |
+
|
553 |
+
synchronize_device(device)
|
554 |
+
logger.info("LoRA weights loaded")
|
555 |
+
|
556 |
+
# save model here before casting to dit_weight_dtype
|
557 |
+
if args.save_merged_model:
|
558 |
+
logger.info(f"Saving merged model to {args.save_merged_model}")
|
559 |
+
mem_eff_save_file(model.state_dict(), args.save_merged_model) # save_file needs a lot of memory
|
560 |
+
logger.info("Merged model saved")
|
561 |
+
|
562 |
+
|
563 |
+
def optimize_model(
|
564 |
+
model: WanModel, args: argparse.Namespace, device: torch.device, dit_dtype: torch.dtype, dit_weight_dtype: torch.dtype
|
565 |
+
) -> None:
|
566 |
+
"""optimize the model (FP8 conversion, device move etc.)
|
567 |
+
|
568 |
+
Args:
|
569 |
+
model: dit model
|
570 |
+
args: command line arguments
|
571 |
+
device: device to use
|
572 |
+
dit_dtype: dtype for the model
|
573 |
+
dit_weight_dtype: dtype for the model weights
|
574 |
+
"""
|
575 |
+
if args.fp8_scaled:
|
576 |
+
# load state dict as-is and optimize to fp8
|
577 |
+
state_dict = model.state_dict()
|
578 |
+
|
579 |
+
# if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy)
|
580 |
+
move_to_device = args.blocks_to_swap == 0 # if blocks_to_swap > 0, we will keep the model on CPU
|
581 |
+
state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=args.fp8_fast)
|
582 |
+
|
583 |
+
info = model.load_state_dict(state_dict, strict=True, assign=True)
|
584 |
+
logger.info(f"Loaded FP8 optimized weights: {info}")
|
585 |
+
|
586 |
+
if args.blocks_to_swap == 0:
|
587 |
+
model.to(device) # make sure all parameters are on the right device (e.g. RoPE etc.)
|
588 |
+
else:
|
589 |
+
# simple cast to dit_dtype
|
590 |
+
target_dtype = None # load as-is (dit_weight_dtype == dtype of the weights in state_dict)
|
591 |
+
target_device = None
|
592 |
+
|
593 |
+
if dit_weight_dtype is not None: # in case of args.fp8 and not args.fp8_scaled
|
594 |
+
logger.info(f"Convert model to {dit_weight_dtype}")
|
595 |
+
target_dtype = dit_weight_dtype
|
596 |
+
|
597 |
+
if args.blocks_to_swap == 0:
|
598 |
+
logger.info(f"Move model to device: {device}")
|
599 |
+
target_device = device
|
600 |
+
|
601 |
+
model.to(target_device, target_dtype) # move and cast at the same time. this reduces redundant copy operations
|
602 |
+
|
603 |
+
if args.compile:
|
604 |
+
compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args
|
605 |
+
logger.info(
|
606 |
+
f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]"
|
607 |
+
)
|
608 |
+
torch._dynamo.config.cache_size_limit = 32
|
609 |
+
for i in range(len(model.blocks)):
|
610 |
+
model.blocks[i] = torch.compile(
|
611 |
+
model.blocks[i],
|
612 |
+
backend=compile_backend,
|
613 |
+
mode=compile_mode,
|
614 |
+
dynamic=compile_dynamic.lower() in "true",
|
615 |
+
fullgraph=compile_fullgraph.lower() in "true",
|
616 |
+
)
|
617 |
+
|
618 |
+
if args.blocks_to_swap > 0:
|
619 |
+
logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}")
|
620 |
+
model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False)
|
621 |
+
model.move_to_device_except_swap_blocks(device)
|
622 |
+
model.prepare_block_swap_before_forward()
|
623 |
+
else:
|
624 |
+
# make sure the model is on the right device
|
625 |
+
model.to(device)
|
626 |
+
|
627 |
+
model.eval().requires_grad_(False)
|
628 |
+
clean_memory_on_device(device)
|
629 |
+
|
630 |
+
|
631 |
+
def prepare_t2v_inputs(
|
632 |
+
args: argparse.Namespace,
|
633 |
+
config,
|
634 |
+
accelerator: Accelerator,
|
635 |
+
device: torch.device,
|
636 |
+
vae: Optional[WanVAE] = None,
|
637 |
+
encoded_context: Optional[Dict] = None,
|
638 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
|
639 |
+
"""Prepare inputs for T2V
|
640 |
+
|
641 |
+
Args:
|
642 |
+
args: command line arguments
|
643 |
+
config: model configuration
|
644 |
+
accelerator: Accelerator instance
|
645 |
+
device: device to use
|
646 |
+
vae: VAE model for control video encoding
|
647 |
+
encoded_context: Pre-encoded text context
|
648 |
+
|
649 |
+
Returns:
|
650 |
+
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
|
651 |
+
(noise, context, context_null, (arg_c, arg_null))
|
652 |
+
"""
|
653 |
+
# Prepare inputs for T2V
|
654 |
+
# calculate dimensions and sequence length
|
655 |
+
height, width = args.video_size
|
656 |
+
frames = args.video_length
|
657 |
+
(_, lat_f, lat_h, lat_w), seq_len = calculate_dimensions(args.video_size, args.video_length, config)
|
658 |
+
target_shape = (16, lat_f, lat_h, lat_w)
|
659 |
+
|
660 |
+
# configure negative prompt
|
661 |
+
n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt
|
662 |
+
|
663 |
+
# set seed
|
664 |
+
seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
|
665 |
+
if not args.cpu_noise:
|
666 |
+
seed_g = torch.Generator(device=device)
|
667 |
+
seed_g.manual_seed(seed)
|
668 |
+
else:
|
669 |
+
# ComfyUI compatible noise
|
670 |
+
seed_g = torch.manual_seed(seed)
|
671 |
+
|
672 |
+
if encoded_context is None:
|
673 |
+
# load text encoder
|
674 |
+
text_encoder = load_text_encoder(args, config, device)
|
675 |
+
text_encoder.model.to(device)
|
676 |
+
|
677 |
+
# encode prompt
|
678 |
+
with torch.no_grad():
|
679 |
+
if args.fp8_t5:
|
680 |
+
with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype):
|
681 |
+
context = text_encoder([args.prompt], device)
|
682 |
+
context_null = text_encoder([n_prompt], device)
|
683 |
+
else:
|
684 |
+
context = text_encoder([args.prompt], device)
|
685 |
+
context_null = text_encoder([n_prompt], device)
|
686 |
+
|
687 |
+
# free text encoder and clean memory
|
688 |
+
del text_encoder
|
689 |
+
clean_memory_on_device(device)
|
690 |
+
else:
|
691 |
+
# Use pre-encoded context
|
692 |
+
context = encoded_context["context"]
|
693 |
+
context_null = encoded_context["context_null"]
|
694 |
+
|
695 |
+
# Fun-Control: encode control video to latent space
|
696 |
+
if config.is_fun_control:
|
697 |
+
# TODO use same resizing as for image
|
698 |
+
logger.info(f"Encoding control video to latent space")
|
699 |
+
# C, F, H, W
|
700 |
+
control_video = load_control_video(args.control_path, frames, height, width).to(device)
|
701 |
+
vae.to_device(device)
|
702 |
+
with torch.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
|
703 |
+
control_latent = vae.encode([control_video])[0]
|
704 |
+
y = torch.concat([control_latent, torch.zeros_like(control_latent)], dim=0) # add control video latent
|
705 |
+
vae.to_device("cpu")
|
706 |
+
else:
|
707 |
+
y = None
|
708 |
+
|
709 |
+
# generate noise
|
710 |
+
noise = torch.randn(target_shape, dtype=torch.float32, generator=seed_g, device=device if not args.cpu_noise else "cpu")
|
711 |
+
noise = noise.to(device)
|
712 |
+
|
713 |
+
# prepare model input arguments
|
714 |
+
arg_c = {"context": context, "seq_len": seq_len}
|
715 |
+
arg_null = {"context": context_null, "seq_len": seq_len}
|
716 |
+
if y is not None:
|
717 |
+
arg_c["y"] = [y]
|
718 |
+
arg_null["y"] = [y]
|
719 |
+
|
720 |
+
return noise, context, context_null, (arg_c, arg_null)
|
721 |
+
|
722 |
+
|
723 |
+
def prepare_i2v_inputs(
|
724 |
+
args: argparse.Namespace,
|
725 |
+
config,
|
726 |
+
accelerator: Accelerator,
|
727 |
+
device: torch.device,
|
728 |
+
vae: WanVAE,
|
729 |
+
encoded_context: Optional[Dict] = None,
|
730 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
|
731 |
+
"""Prepare inputs for I2V
|
732 |
+
|
733 |
+
Args:
|
734 |
+
args: command line arguments
|
735 |
+
config: model configuration
|
736 |
+
accelerator: Accelerator instance
|
737 |
+
device: device to use
|
738 |
+
vae: VAE model, used for image encoding
|
739 |
+
encoded_context: Pre-encoded text context
|
740 |
+
|
741 |
+
Returns:
|
742 |
+
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]:
|
743 |
+
(noise, context, context_null, y, (arg_c, arg_null))
|
744 |
+
"""
|
745 |
+
# get video dimensions
|
746 |
+
height, width = args.video_size
|
747 |
+
frames = args.video_length
|
748 |
+
max_area = width * height
|
749 |
+
|
750 |
+
# load image
|
751 |
+
img = Image.open(args.image_path).convert("RGB")
|
752 |
+
|
753 |
+
# convert to numpy
|
754 |
+
img_cv2 = np.array(img) # PIL to numpy
|
755 |
+
|
756 |
+
# convert to tensor (-1 to 1)
|
757 |
+
img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)
|
758 |
+
|
759 |
+
# end frame image
|
760 |
+
if args.end_image_path is not None:
|
761 |
+
end_img = Image.open(args.end_image_path).convert("RGB")
|
762 |
+
end_img_cv2 = np.array(end_img) # PIL to numpy
|
763 |
+
else:
|
764 |
+
end_img = None
|
765 |
+
end_img_cv2 = None
|
766 |
+
has_end_image = end_img is not None
|
767 |
+
|
768 |
+
# calculate latent dimensions: keep aspect ratio
|
769 |
+
height, width = img_tensor.shape[1:]
|
770 |
+
aspect_ratio = height / width
|
771 |
+
lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
|
772 |
+
lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
|
773 |
+
height = lat_h * config.vae_stride[1]
|
774 |
+
width = lat_w * config.vae_stride[2]
|
775 |
+
lat_f = (frames - 1) // config.vae_stride[0] + 1 # size of latent frames
|
776 |
+
max_seq_len = (lat_f + (1 if has_end_image else 0)) * lat_h * lat_w // (config.patch_size[1] * config.patch_size[2])
|
777 |
+
|
778 |
+
# set seed
|
779 |
+
seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
|
780 |
+
if not args.cpu_noise:
|
781 |
+
seed_g = torch.Generator(device=device)
|
782 |
+
seed_g.manual_seed(seed)
|
783 |
+
else:
|
784 |
+
# ComfyUI compatible noise
|
785 |
+
seed_g = torch.manual_seed(seed)
|
786 |
+
|
787 |
+
# generate noise
|
788 |
+
noise = torch.randn(
|
789 |
+
16,
|
790 |
+
lat_f + (1 if has_end_image else 0),
|
791 |
+
lat_h,
|
792 |
+
lat_w,
|
793 |
+
dtype=torch.float32,
|
794 |
+
generator=seed_g,
|
795 |
+
device=device if not args.cpu_noise else "cpu",
|
796 |
+
)
|
797 |
+
noise = noise.to(device)
|
798 |
+
|
799 |
+
# configure negative prompt
|
800 |
+
n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt
|
801 |
+
|
802 |
+
if encoded_context is None:
|
803 |
+
# load text encoder
|
804 |
+
text_encoder = load_text_encoder(args, config, device)
|
805 |
+
text_encoder.model.to(device)
|
806 |
+
|
807 |
+
# encode prompt
|
808 |
+
with torch.no_grad():
|
809 |
+
if args.fp8_t5:
|
810 |
+
with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype):
|
811 |
+
context = text_encoder([args.prompt], device)
|
812 |
+
context_null = text_encoder([n_prompt], device)
|
813 |
+
else:
|
814 |
+
context = text_encoder([args.prompt], device)
|
815 |
+
context_null = text_encoder([n_prompt], device)
|
816 |
+
|
817 |
+
# free text encoder and clean memory
|
818 |
+
del text_encoder
|
819 |
+
clean_memory_on_device(device)
|
820 |
+
|
821 |
+
# load CLIP model
|
822 |
+
clip = load_clip_model(args, config, device)
|
823 |
+
clip.model.to(device)
|
824 |
+
|
825 |
+
# encode image to CLIP context
|
826 |
+
logger.info(f"Encoding image to CLIP context")
|
827 |
+
with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
|
828 |
+
clip_context = clip.visual([img_tensor[:, None, :, :]])
|
829 |
+
logger.info(f"Encoding complete")
|
830 |
+
|
831 |
+
# free CLIP model and clean memory
|
832 |
+
del clip
|
833 |
+
clean_memory_on_device(device)
|
834 |
+
else:
|
835 |
+
# Use pre-encoded context
|
836 |
+
context = encoded_context["context"]
|
837 |
+
context_null = encoded_context["context_null"]
|
838 |
+
clip_context = encoded_context["clip_context"]
|
839 |
+
|
840 |
+
# encode image to latent space with VAE
|
841 |
+
logger.info(f"Encoding image to latent space")
|
842 |
+
vae.to_device(device)
|
843 |
+
|
844 |
+
# resize image
|
845 |
+
interpolation = cv2.INTER_AREA if height < img_cv2.shape[0] else cv2.INTER_CUBIC
|
846 |
+
img_resized = cv2.resize(img_cv2, (width, height), interpolation=interpolation)
|
847 |
+
img_resized = TF.to_tensor(img_resized).sub_(0.5).div_(0.5).to(device) # -1 to 1, CHW
|
848 |
+
img_resized = img_resized.unsqueeze(1) # CFHW
|
849 |
+
|
850 |
+
if has_end_image:
|
851 |
+
interpolation = cv2.INTER_AREA if height < end_img_cv2.shape[1] else cv2.INTER_CUBIC
|
852 |
+
end_img_resized = cv2.resize(end_img_cv2, (width, height), interpolation=interpolation)
|
853 |
+
end_img_resized = TF.to_tensor(end_img_resized).sub_(0.5).div_(0.5).to(device) # -1 to 1, CHW
|
854 |
+
end_img_resized = end_img_resized.unsqueeze(1) # CFHW
|
855 |
+
|
856 |
+
# create mask for the first frame
|
857 |
+
msk = torch.zeros(4, lat_f + (1 if has_end_image else 0), lat_h, lat_w, device=device)
|
858 |
+
msk[:, 0] = 1
|
859 |
+
if has_end_image:
|
860 |
+
msk[:, -1] = 1
|
861 |
+
|
862 |
+
# encode image to latent space
|
863 |
+
with accelerator.autocast(), torch.no_grad():
|
864 |
+
# padding to match the required number of frames
|
865 |
+
padding_frames = frames - 1 # the first frame is image
|
866 |
+
img_resized = torch.concat([img_resized, torch.zeros(3, padding_frames, height, width, device=device)], dim=1)
|
867 |
+
y = vae.encode([img_resized])[0]
|
868 |
+
|
869 |
+
if has_end_image:
|
870 |
+
y_end = vae.encode([end_img_resized])[0]
|
871 |
+
y = torch.concat([y, y_end], dim=1) # add end frame
|
872 |
+
|
873 |
+
y = torch.concat([msk, y])
|
874 |
+
logger.info(f"Encoding complete")
|
875 |
+
|
876 |
+
# Fun-Control: encode control video to latent space
|
877 |
+
if config.is_fun_control:
|
878 |
+
# TODO use same resizing as for image
|
879 |
+
logger.info(f"Encoding control video to latent space")
|
880 |
+
# C, F, H, W
|
881 |
+
control_video = load_control_video(args.control_path, frames + (1 if has_end_image else 0), height, width).to(device)
|
882 |
+
with accelerator.autocast(), torch.no_grad():
|
883 |
+
control_latent = vae.encode([control_video])[0]
|
884 |
+
y = y[msk.shape[0] :] # remove mask because Fun-Control does not need it
|
885 |
+
if has_end_image:
|
886 |
+
y[:, 1:-1] = 0 # remove image latent except first and last frame. according to WanVideoWrapper, this doesn't work
|
887 |
+
else:
|
888 |
+
y[:, 1:] = 0 # remove image latent except first frame
|
889 |
+
y = torch.concat([control_latent, y], dim=0) # add control video latent
|
890 |
+
|
891 |
+
# prepare model input arguments
|
892 |
+
arg_c = {
|
893 |
+
"context": [context[0]],
|
894 |
+
"clip_fea": clip_context,
|
895 |
+
"seq_len": max_seq_len,
|
896 |
+
"y": [y],
|
897 |
+
}
|
898 |
+
|
899 |
+
arg_null = {
|
900 |
+
"context": context_null,
|
901 |
+
"clip_fea": clip_context,
|
902 |
+
"seq_len": max_seq_len,
|
903 |
+
"y": [y],
|
904 |
+
}
|
905 |
+
|
906 |
+
vae.to_device("cpu") # move VAE to CPU to save memory
|
907 |
+
clean_memory_on_device(device)
|
908 |
+
|
909 |
+
return noise, context, context_null, y, (arg_c, arg_null)
|
910 |
+
|
911 |
+
|
912 |
+
def load_control_video(control_path: str, frames: int, height: int, width: int) -> torch.Tensor:
|
913 |
+
"""load control video to latent space
|
914 |
+
|
915 |
+
Args:
|
916 |
+
control_path: path to control video
|
917 |
+
frames: number of frames in the video
|
918 |
+
height: height of the video
|
919 |
+
width: width of the video
|
920 |
+
|
921 |
+
Returns:
|
922 |
+
torch.Tensor: control video latent, CFHW
|
923 |
+
"""
|
924 |
+
logger.info(f"Load control video from {control_path}")
|
925 |
+
video = load_video(control_path, 0, frames, bucket_reso=(width, height)) # list of frames
|
926 |
+
if len(video) < frames:
|
927 |
+
raise ValueError(f"Video length is less than {frames}")
|
928 |
+
# video = np.stack(video, axis=0) # F, H, W, C
|
929 |
+
video = torch.stack([TF.to_tensor(frame).sub_(0.5).div_(0.5) for frame in video], dim=0) # F, C, H, W, -1 to 1
|
930 |
+
video = video.permute(1, 0, 2, 3) # C, F, H, W
|
931 |
+
return video
|
932 |
+
|
933 |
+
|
934 |
+
def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]:
|
935 |
+
"""setup scheduler for sampling
|
936 |
+
|
937 |
+
Args:
|
938 |
+
args: command line arguments
|
939 |
+
config: model configuration
|
940 |
+
device: device to use
|
941 |
+
|
942 |
+
Returns:
|
943 |
+
Tuple[Any, torch.Tensor]: (scheduler, timesteps)
|
944 |
+
"""
|
945 |
+
if args.sample_solver == "unipc":
|
946 |
+
scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False)
|
947 |
+
scheduler.set_timesteps(args.infer_steps, device=device, shift=args.flow_shift)
|
948 |
+
timesteps = scheduler.timesteps
|
949 |
+
elif args.sample_solver == "dpm++":
|
950 |
+
scheduler = FlowDPMSolverMultistepScheduler(
|
951 |
+
num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False
|
952 |
+
)
|
953 |
+
sampling_sigmas = get_sampling_sigmas(args.infer_steps, args.flow_shift)
|
954 |
+
timesteps, _ = retrieve_timesteps(scheduler, device=device, sigmas=sampling_sigmas)
|
955 |
+
elif args.sample_solver == "vanilla":
|
956 |
+
scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=config.num_train_timesteps, shift=args.flow_shift)
|
957 |
+
scheduler.set_timesteps(args.infer_steps, device=device)
|
958 |
+
timesteps = scheduler.timesteps
|
959 |
+
|
960 |
+
# FlowMatchDiscreteScheduler does not support generator argument in step method
|
961 |
+
org_step = scheduler.step
|
962 |
+
|
963 |
+
def step_wrapper(
|
964 |
+
model_output: torch.Tensor,
|
965 |
+
timestep: Union[int, torch.Tensor],
|
966 |
+
sample: torch.Tensor,
|
967 |
+
return_dict: bool = True,
|
968 |
+
generator=None,
|
969 |
+
):
|
970 |
+
return org_step(model_output, timestep, sample, return_dict=return_dict)
|
971 |
+
|
972 |
+
scheduler.step = step_wrapper
|
973 |
+
else:
|
974 |
+
raise NotImplementedError("Unsupported solver.")
|
975 |
+
|
976 |
+
return scheduler, timesteps
|
977 |
+
|
978 |
+
|
979 |
+
def run_sampling(
|
980 |
+
model: WanModel,
|
981 |
+
noise: torch.Tensor,
|
982 |
+
scheduler: Any,
|
983 |
+
timesteps: torch.Tensor,
|
984 |
+
args: argparse.Namespace,
|
985 |
+
inputs: Tuple[dict, dict],
|
986 |
+
device: torch.device,
|
987 |
+
seed_g: torch.Generator,
|
988 |
+
accelerator: Accelerator,
|
989 |
+
is_i2v: bool = False,
|
990 |
+
use_cpu_offload: bool = True,
|
991 |
+
) -> torch.Tensor:
|
992 |
+
"""run sampling
|
993 |
+
Args:
|
994 |
+
model: dit model
|
995 |
+
noise: initial noise
|
996 |
+
scheduler: scheduler for sampling
|
997 |
+
timesteps: time steps for sampling
|
998 |
+
args: command line arguments
|
999 |
+
inputs: model input (arg_c, arg_null)
|
1000 |
+
device: device to use
|
1001 |
+
seed_g: random generator
|
1002 |
+
accelerator: Accelerator instance
|
1003 |
+
is_i2v: I2V mode (False means T2V mode)
|
1004 |
+
use_cpu_offload: Whether to offload tensors to CPU during processing
|
1005 |
+
Returns:
|
1006 |
+
torch.Tensor: generated latent
|
1007 |
+
"""
|
1008 |
+
arg_c, arg_null = inputs
|
1009 |
+
|
1010 |
+
latent = noise
|
1011 |
+
latent_storage_device = device if not use_cpu_offload else "cpu"
|
1012 |
+
latent = latent.to(latent_storage_device)
|
1013 |
+
|
1014 |
+
# cfg skip
|
1015 |
+
apply_cfg_array = []
|
1016 |
+
num_timesteps = len(timesteps)
|
1017 |
+
|
1018 |
+
if args.cfg_skip_mode != "none" and args.cfg_apply_ratio is not None:
|
1019 |
+
# Calculate thresholds based on cfg_apply_ratio
|
1020 |
+
apply_steps = int(num_timesteps * args.cfg_apply_ratio)
|
1021 |
+
|
1022 |
+
if args.cfg_skip_mode == "early":
|
1023 |
+
# Skip CFG in early steps, apply in late steps
|
1024 |
+
start_index = num_timesteps - apply_steps
|
1025 |
+
end_index = num_timesteps
|
1026 |
+
elif args.cfg_skip_mode == "late":
|
1027 |
+
# Skip CFG in late steps, apply in early steps
|
1028 |
+
start_index = 0
|
1029 |
+
end_index = apply_steps
|
1030 |
+
elif args.cfg_skip_mode == "early_late":
|
1031 |
+
# Skip CFG in early and late steps, apply in middle steps
|
1032 |
+
start_index = (num_timesteps - apply_steps) // 2
|
1033 |
+
end_index = start_index + apply_steps
|
1034 |
+
elif args.cfg_skip_mode == "middle":
|
1035 |
+
# Skip CFG in middle steps, apply in early and late steps
|
1036 |
+
skip_steps = num_timesteps - apply_steps
|
1037 |
+
middle_start = (num_timesteps - skip_steps) // 2
|
1038 |
+
middle_end = middle_start + skip_steps
|
1039 |
+
|
1040 |
+
w = 0.0
|
1041 |
+
for step_idx in range(num_timesteps):
|
1042 |
+
if args.cfg_skip_mode == "alternate":
|
1043 |
+
# accumulate w and apply CFG when w >= 1.0
|
1044 |
+
w += args.cfg_apply_ratio
|
1045 |
+
apply = w >= 1.0
|
1046 |
+
if apply:
|
1047 |
+
w -= 1.0
|
1048 |
+
elif args.cfg_skip_mode == "middle":
|
1049 |
+
# Skip CFG in early and late steps, apply in middle steps
|
1050 |
+
apply = step_idx < middle_start or step_idx >= middle_end
|
1051 |
+
else:
|
1052 |
+
# Apply CFG on some steps based on ratio
|
1053 |
+
apply = step_idx >= start_index and step_idx < end_index
|
1054 |
+
|
1055 |
+
apply_cfg_array.append(apply)
|
1056 |
+
|
1057 |
+
pattern = ["A" if apply else "S" for apply in apply_cfg_array]
|
1058 |
+
pattern = "".join(pattern)
|
1059 |
+
logger.info(f"CFG skip mode: {args.cfg_skip_mode}, apply ratio: {args.cfg_apply_ratio}, pattern: {pattern}")
|
1060 |
+
else:
|
1061 |
+
# Apply CFG on all steps
|
1062 |
+
apply_cfg_array = [True] * num_timesteps
|
1063 |
+
|
1064 |
+
# SLG original implementation is based on https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py
|
1065 |
+
slg_start_step = int(args.slg_start * num_timesteps)
|
1066 |
+
slg_end_step = int(args.slg_end * num_timesteps)
|
1067 |
+
|
1068 |
+
for i, t in enumerate(tqdm(timesteps)):
|
1069 |
+
# latent is on CPU if use_cpu_offload is True
|
1070 |
+
latent_model_input = [latent.to(device)]
|
1071 |
+
timestep = torch.stack([t]).to(device)
|
1072 |
+
|
1073 |
+
with accelerator.autocast(), torch.no_grad():
|
1074 |
+
noise_pred_cond = model(latent_model_input, t=timestep, **arg_c)[0].to(latent_storage_device)
|
1075 |
+
|
1076 |
+
apply_cfg = apply_cfg_array[i] # apply CFG or not
|
1077 |
+
if apply_cfg:
|
1078 |
+
apply_slg = i >= slg_start_step and i < slg_end_step
|
1079 |
+
# print(f"Applying SLG: {apply_slg}, i: {i}, slg_start_step: {slg_start_step}, slg_end_step: {slg_end_step}")
|
1080 |
+
if args.slg_mode == "original" and apply_slg:
|
1081 |
+
noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to(latent_storage_device)
|
1082 |
+
|
1083 |
+
# apply guidance
|
1084 |
+
# SD3 formula: scaled = neg_out + (pos_out - neg_out) * cond_scale
|
1085 |
+
noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
1086 |
+
|
1087 |
+
# calculate skip layer out
|
1088 |
+
skip_layer_out = model(latent_model_input, t=timestep, skip_block_indices=args.slg_layers, **arg_null)[0].to(
|
1089 |
+
latent_storage_device
|
1090 |
+
)
|
1091 |
+
|
1092 |
+
# apply skip layer guidance
|
1093 |
+
# SD3 formula: scaled = scaled + (pos_out - skip_layer_out) * self.slg
|
1094 |
+
noise_pred = noise_pred + args.slg_scale * (noise_pred_cond - skip_layer_out)
|
1095 |
+
elif args.slg_mode == "uncond" and apply_slg:
|
1096 |
+
# noise_pred_uncond is skip layer out
|
1097 |
+
noise_pred_uncond = model(latent_model_input, t=timestep, skip_block_indices=args.slg_layers, **arg_null)[0].to(
|
1098 |
+
latent_storage_device
|
1099 |
+
)
|
1100 |
+
|
1101 |
+
# apply guidance
|
1102 |
+
noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
1103 |
+
|
1104 |
+
else:
|
1105 |
+
# normal guidance
|
1106 |
+
noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to(latent_storage_device)
|
1107 |
+
|
1108 |
+
# apply guidance
|
1109 |
+
noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
1110 |
+
else:
|
1111 |
+
noise_pred = noise_pred_cond
|
1112 |
+
|
1113 |
+
# step
|
1114 |
+
latent_input = latent.unsqueeze(0)
|
1115 |
+
temp_x0 = scheduler.step(noise_pred.unsqueeze(0), t, latent_input, return_dict=False, generator=seed_g)[0]
|
1116 |
+
|
1117 |
+
# update latent
|
1118 |
+
latent = temp_x0.squeeze(0)
|
1119 |
+
|
1120 |
+
return latent
|
1121 |
+
|
1122 |
+
|
1123 |
+
def generate(args: argparse.Namespace, gen_settings: GenerationSettings, shared_models: Optional[Dict] = None) -> torch.Tensor:
|
1124 |
+
"""main function for generation
|
1125 |
+
|
1126 |
+
Args:
|
1127 |
+
args: command line arguments
|
1128 |
+
shared_models: dictionary containing pre-loaded models and encoded data
|
1129 |
+
|
1130 |
+
Returns:
|
1131 |
+
torch.Tensor: generated latent
|
1132 |
+
"""
|
1133 |
+
device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = (
|
1134 |
+
gen_settings.device,
|
1135 |
+
gen_settings.cfg,
|
1136 |
+
gen_settings.dit_dtype,
|
1137 |
+
gen_settings.dit_weight_dtype,
|
1138 |
+
gen_settings.vae_dtype,
|
1139 |
+
)
|
1140 |
+
|
1141 |
+
# prepare accelerator
|
1142 |
+
mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16"
|
1143 |
+
accelerator = accelerate.Accelerator(mixed_precision=mixed_precision)
|
1144 |
+
|
1145 |
+
# I2V or T2V
|
1146 |
+
is_i2v = "i2v" in args.task
|
1147 |
+
|
1148 |
+
# prepare seed
|
1149 |
+
seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
|
1150 |
+
args.seed = seed # set seed to args for saving
|
1151 |
+
|
1152 |
+
# Check if we have shared models
|
1153 |
+
if shared_models is not None:
|
1154 |
+
# Use shared models and encoded data
|
1155 |
+
vae = shared_models.get("vae")
|
1156 |
+
model = shared_models.get("model")
|
1157 |
+
encoded_context = shared_models.get("encoded_contexts", {}).get(args.prompt)
|
1158 |
+
|
1159 |
+
# prepare inputs
|
1160 |
+
if is_i2v:
|
1161 |
+
# I2V
|
1162 |
+
noise, context, context_null, y, inputs = prepare_i2v_inputs(args, cfg, accelerator, device, vae, encoded_context)
|
1163 |
+
else:
|
1164 |
+
# T2V
|
1165 |
+
noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, vae, encoded_context)
|
1166 |
+
else:
|
1167 |
+
# prepare inputs without shared models
|
1168 |
+
if is_i2v:
|
1169 |
+
# I2V: need text encoder, VAE and CLIP
|
1170 |
+
vae = load_vae(args, cfg, device, vae_dtype)
|
1171 |
+
noise, context, context_null, y, inputs = prepare_i2v_inputs(args, cfg, accelerator, device, vae)
|
1172 |
+
# vae is on CPU after prepare_i2v_inputs
|
1173 |
+
else:
|
1174 |
+
# T2V: need text encoder
|
1175 |
+
vae = None
|
1176 |
+
if cfg.is_fun_control:
|
1177 |
+
# Fun-Control: need VAE for encoding control video
|
1178 |
+
vae = load_vae(args, cfg, device, vae_dtype)
|
1179 |
+
noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, vae)
|
1180 |
+
|
1181 |
+
# load DiT model
|
1182 |
+
model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v)
|
1183 |
+
|
1184 |
+
# merge LoRA weights
|
1185 |
+
if args.lora_weight is not None and len(args.lora_weight) > 0:
|
1186 |
+
merge_lora_weights(lora_wan, model, args, device)
|
1187 |
+
|
1188 |
+
# if we only want to save the model, we can skip the rest
|
1189 |
+
if args.save_merged_model:
|
1190 |
+
return None
|
1191 |
+
|
1192 |
+
# optimize model: fp8 conversion, block swap etc.
|
1193 |
+
optimize_model(model, args, device, dit_dtype, dit_weight_dtype)
|
1194 |
+
|
1195 |
+
# setup scheduler
|
1196 |
+
scheduler, timesteps = setup_scheduler(args, cfg, device)
|
1197 |
+
|
1198 |
+
# set random generator
|
1199 |
+
seed_g = torch.Generator(device=device)
|
1200 |
+
seed_g.manual_seed(seed)
|
1201 |
+
|
1202 |
+
# run sampling
|
1203 |
+
latent = run_sampling(model, noise, scheduler, timesteps, args, inputs, device, seed_g, accelerator, is_i2v)
|
1204 |
+
|
1205 |
+
# Only clean up shared models if they were created within this function
|
1206 |
+
if shared_models is None:
|
1207 |
+
# free memory
|
1208 |
+
del model
|
1209 |
+
del scheduler
|
1210 |
+
synchronize_device(device)
|
1211 |
+
|
1212 |
+
# wait for 5 seconds until block swap is done
|
1213 |
+
logger.info("Waiting for 5 seconds to finish block swap")
|
1214 |
+
time.sleep(5)
|
1215 |
+
|
1216 |
+
gc.collect()
|
1217 |
+
clean_memory_on_device(device)
|
1218 |
+
|
1219 |
+
# save VAE model for decoding
|
1220 |
+
if vae is None:
|
1221 |
+
args._vae = None
|
1222 |
+
else:
|
1223 |
+
args._vae = vae
|
1224 |
+
|
1225 |
+
return latent
|
1226 |
+
|
1227 |
+
|
1228 |
+
def decode_latent(latent: torch.Tensor, args: argparse.Namespace, cfg) -> torch.Tensor:
|
1229 |
+
"""decode latent
|
1230 |
+
|
1231 |
+
Args:
|
1232 |
+
latent: latent tensor
|
1233 |
+
args: command line arguments
|
1234 |
+
cfg: model configuration
|
1235 |
+
|
1236 |
+
Returns:
|
1237 |
+
torch.Tensor: decoded video or image
|
1238 |
+
"""
|
1239 |
+
device = torch.device(args.device)
|
1240 |
+
|
1241 |
+
# load VAE model or use the one from the generation
|
1242 |
+
vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else torch.bfloat16
|
1243 |
+
if hasattr(args, "_vae") and args._vae is not None:
|
1244 |
+
vae = args._vae
|
1245 |
+
else:
|
1246 |
+
vae = load_vae(args, cfg, device, vae_dtype)
|
1247 |
+
|
1248 |
+
vae.to_device(device)
|
1249 |
+
|
1250 |
+
logger.info(f"Decoding video from latents: {latent.shape}")
|
1251 |
+
x0 = latent.to(device)
|
1252 |
+
|
1253 |
+
with torch.autocast(device_type=device.type, dtype=vae_dtype), torch.no_grad():
|
1254 |
+
videos = vae.decode(x0)
|
1255 |
+
|
1256 |
+
# some tail frames may be corrupted when end frame is used, we add an option to remove them
|
1257 |
+
if args.trim_tail_frames:
|
1258 |
+
videos[0] = videos[0][:, : -args.trim_tail_frames]
|
1259 |
+
|
1260 |
+
logger.info(f"Decoding complete")
|
1261 |
+
video = videos[0]
|
1262 |
+
del videos
|
1263 |
+
video = video.to(torch.float32).cpu()
|
1264 |
+
|
1265 |
+
return video
|
1266 |
+
|
1267 |
+
|
1268 |
+
def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str:
|
1269 |
+
"""Save latent to file
|
1270 |
+
|
1271 |
+
Args:
|
1272 |
+
latent: latent tensor
|
1273 |
+
args: command line arguments
|
1274 |
+
height: height of frame
|
1275 |
+
width: width of frame
|
1276 |
+
|
1277 |
+
Returns:
|
1278 |
+
str: Path to saved latent file
|
1279 |
+
"""
|
1280 |
+
save_path = args.save_path
|
1281 |
+
os.makedirs(save_path, exist_ok=True)
|
1282 |
+
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
|
1283 |
+
|
1284 |
+
seed = args.seed
|
1285 |
+
video_length = args.video_length
|
1286 |
+
latent_path = f"{save_path}/{time_flag}_{seed}_latent.safetensors"
|
1287 |
+
|
1288 |
+
if args.no_metadata:
|
1289 |
+
metadata = None
|
1290 |
+
else:
|
1291 |
+
metadata = {
|
1292 |
+
"seeds": f"{seed}",
|
1293 |
+
"prompt": f"{args.prompt}",
|
1294 |
+
"height": f"{height}",
|
1295 |
+
"width": f"{width}",
|
1296 |
+
"video_length": f"{video_length}",
|
1297 |
+
"infer_steps": f"{args.infer_steps}",
|
1298 |
+
"guidance_scale": f"{args.guidance_scale}",
|
1299 |
+
}
|
1300 |
+
if args.negative_prompt is not None:
|
1301 |
+
metadata["negative_prompt"] = f"{args.negative_prompt}"
|
1302 |
+
|
1303 |
+
sd = {"latent": latent}
|
1304 |
+
save_file(sd, latent_path, metadata=metadata)
|
1305 |
+
logger.info(f"Latent saved to: {latent_path}")
|
1306 |
+
|
1307 |
+
return latent_path
|
1308 |
+
|
1309 |
+
|
1310 |
+
def save_video(video: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
|
1311 |
+
"""Save video to file
|
1312 |
+
|
1313 |
+
Args:
|
1314 |
+
video: Video tensor
|
1315 |
+
args: command line arguments
|
1316 |
+
original_base_name: Original base name (if latents are loaded from files)
|
1317 |
+
|
1318 |
+
Returns:
|
1319 |
+
str: Path to saved video file
|
1320 |
+
"""
|
1321 |
+
save_path = args.save_path
|
1322 |
+
os.makedirs(save_path, exist_ok=True)
|
1323 |
+
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
|
1324 |
+
|
1325 |
+
seed = args.seed
|
1326 |
+
original_name = "" if original_base_name is None else f"_{original_base_name}"
|
1327 |
+
video_path = f"{save_path}/{time_flag}_{seed}{original_name}.mp4"
|
1328 |
+
|
1329 |
+
video = video.unsqueeze(0)
|
1330 |
+
save_videos_grid(video, video_path, fps=args.fps, rescale=True)
|
1331 |
+
logger.info(f"Video saved to: {video_path}")
|
1332 |
+
|
1333 |
+
return video_path
|
1334 |
+
|
1335 |
+
|
1336 |
+
def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
|
1337 |
+
"""Save images to directory
|
1338 |
+
|
1339 |
+
Args:
|
1340 |
+
sample: Video tensor
|
1341 |
+
args: command line arguments
|
1342 |
+
original_base_name: Original base name (if latents are loaded from files)
|
1343 |
+
|
1344 |
+
Returns:
|
1345 |
+
str: Path to saved images directory
|
1346 |
+
"""
|
1347 |
+
save_path = args.save_path
|
1348 |
+
os.makedirs(save_path, exist_ok=True)
|
1349 |
+
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
|
1350 |
+
|
1351 |
+
seed = args.seed
|
1352 |
+
original_name = "" if original_base_name is None else f"_{original_base_name}"
|
1353 |
+
image_name = f"{time_flag}_{seed}{original_name}"
|
1354 |
+
sample = sample.unsqueeze(0)
|
1355 |
+
save_images_grid(sample, save_path, image_name, rescale=True)
|
1356 |
+
logger.info(f"Sample images saved to: {save_path}/{image_name}")
|
1357 |
+
|
1358 |
+
return f"{save_path}/{image_name}"
|
1359 |
+
|
1360 |
+
|
1361 |
+
def save_output(
|
1362 |
+
latent: torch.Tensor, args: argparse.Namespace, cfg, height: int, width: int, original_base_names: Optional[List[str]] = None
|
1363 |
+
) -> None:
|
1364 |
+
"""save output
|
1365 |
+
|
1366 |
+
Args:
|
1367 |
+
latent: latent tensor
|
1368 |
+
args: command line arguments
|
1369 |
+
cfg: model configuration
|
1370 |
+
height: height of frame
|
1371 |
+
width: width of frame
|
1372 |
+
original_base_names: original base names (if latents are loaded from files)
|
1373 |
+
"""
|
1374 |
+
if args.output_type == "latent" or args.output_type == "both":
|
1375 |
+
# save latent
|
1376 |
+
save_latent(latent, args, height, width)
|
1377 |
+
|
1378 |
+
if args.output_type == "video" or args.output_type == "both":
|
1379 |
+
# save video
|
1380 |
+
sample = decode_latent(latent.unsqueeze(0), args, cfg)
|
1381 |
+
original_name = "" if original_base_names is None else f"_{original_base_names[0]}"
|
1382 |
+
save_video(sample, args, original_name)
|
1383 |
+
|
1384 |
+
elif args.output_type == "images":
|
1385 |
+
# save images
|
1386 |
+
sample = decode_latent(latent.unsqueeze(0), args, cfg)
|
1387 |
+
original_name = "" if original_base_names is None else f"_{original_base_names[0]}"
|
1388 |
+
save_images(sample, args, original_name)
|
1389 |
+
|
1390 |
+
|
1391 |
+
def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]:
|
1392 |
+
"""Process multiple prompts for batch mode
|
1393 |
+
|
1394 |
+
Args:
|
1395 |
+
prompt_lines: List of prompt lines
|
1396 |
+
base_args: Base command line arguments
|
1397 |
+
|
1398 |
+
Returns:
|
1399 |
+
List[Dict]: List of prompt data dictionaries
|
1400 |
+
"""
|
1401 |
+
prompts_data = []
|
1402 |
+
|
1403 |
+
for line in prompt_lines:
|
1404 |
+
line = line.strip()
|
1405 |
+
if not line or line.startswith("#"): # Skip empty lines and comments
|
1406 |
+
continue
|
1407 |
+
|
1408 |
+
# Parse prompt line and create override dictionary
|
1409 |
+
prompt_data = parse_prompt_line(line)
|
1410 |
+
logger.info(f"Parsed prompt data: {prompt_data}")
|
1411 |
+
prompts_data.append(prompt_data)
|
1412 |
+
|
1413 |
+
return prompts_data
|
1414 |
+
|
1415 |
+
|
1416 |
+
def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) -> None:
|
1417 |
+
"""Process multiple prompts with model reuse
|
1418 |
+
|
1419 |
+
Args:
|
1420 |
+
prompts_data: List of prompt data dictionaries
|
1421 |
+
args: Base command line arguments
|
1422 |
+
"""
|
1423 |
+
if not prompts_data:
|
1424 |
+
logger.warning("No valid prompts found")
|
1425 |
+
return
|
1426 |
+
|
1427 |
+
# 1. Load configuration
|
1428 |
+
gen_settings = get_generation_settings(args)
|
1429 |
+
device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = (
|
1430 |
+
gen_settings.device,
|
1431 |
+
gen_settings.cfg,
|
1432 |
+
gen_settings.dit_dtype,
|
1433 |
+
gen_settings.dit_weight_dtype,
|
1434 |
+
gen_settings.vae_dtype,
|
1435 |
+
)
|
1436 |
+
is_i2v = "i2v" in args.task
|
1437 |
+
|
1438 |
+
# 2. Encode all prompts
|
1439 |
+
logger.info("Loading text encoder to encode all prompts")
|
1440 |
+
text_encoder = load_text_encoder(args, cfg, device)
|
1441 |
+
text_encoder.model.to(device)
|
1442 |
+
|
1443 |
+
encoded_contexts = {}
|
1444 |
+
|
1445 |
+
with torch.no_grad():
|
1446 |
+
for prompt_data in prompts_data:
|
1447 |
+
prompt = prompt_data["prompt"]
|
1448 |
+
prompt_args = apply_overrides(args, prompt_data)
|
1449 |
+
n_prompt = prompt_data.get(
|
1450 |
+
"negative_prompt", prompt_args.negative_prompt if prompt_args.negative_prompt else cfg.sample_neg_prompt
|
1451 |
+
)
|
1452 |
+
|
1453 |
+
if args.fp8_t5:
|
1454 |
+
with torch.amp.autocast(device_type=device.type, dtype=cfg.t5_dtype):
|
1455 |
+
context = text_encoder([prompt], device)
|
1456 |
+
context_null = text_encoder([n_prompt], device)
|
1457 |
+
else:
|
1458 |
+
context = text_encoder([prompt], device)
|
1459 |
+
context_null = text_encoder([n_prompt], device)
|
1460 |
+
|
1461 |
+
encoded_contexts[prompt] = {"context": context, "context_null": context_null}
|
1462 |
+
|
1463 |
+
# Free text encoder and clean memory
|
1464 |
+
del text_encoder
|
1465 |
+
clean_memory_on_device(device)
|
1466 |
+
|
1467 |
+
# 3. Process I2V additional encodings if needed
|
1468 |
+
vae = None
|
1469 |
+
if is_i2v:
|
1470 |
+
logger.info("Loading VAE and CLIP for I2V preprocessing")
|
1471 |
+
vae = load_vae(args, cfg, device, vae_dtype)
|
1472 |
+
vae.to_device(device)
|
1473 |
+
|
1474 |
+
clip = load_clip_model(args, cfg, device)
|
1475 |
+
clip.model.to(device)
|
1476 |
+
|
1477 |
+
# Process each image and encode with CLIP
|
1478 |
+
for prompt_data in prompts_data:
|
1479 |
+
if "image_path" not in prompt_data:
|
1480 |
+
continue
|
1481 |
+
|
1482 |
+
prompt_args = apply_overrides(args, prompt_data)
|
1483 |
+
if not os.path.exists(prompt_args.image_path):
|
1484 |
+
logger.warning(f"Image path not found: {prompt_args.image_path}")
|
1485 |
+
continue
|
1486 |
+
|
1487 |
+
# Load and encode image with CLIP
|
1488 |
+
img = Image.open(prompt_args.image_path).convert("RGB")
|
1489 |
+
img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)
|
1490 |
+
|
1491 |
+
with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
|
1492 |
+
clip_context = clip.visual([img_tensor[:, None, :, :]])
|
1493 |
+
|
1494 |
+
encoded_contexts[prompt_data["prompt"]]["clip_context"] = clip_context
|
1495 |
+
|
1496 |
+
# Free CLIP and clean memory
|
1497 |
+
del clip
|
1498 |
+
clean_memory_on_device(device)
|
1499 |
+
|
1500 |
+
# Keep VAE in CPU memory for later use
|
1501 |
+
vae.to_device("cpu")
|
1502 |
+
elif cfg.is_fun_control:
|
1503 |
+
# For Fun-Control, we need VAE but keep it on CPU
|
1504 |
+
vae = load_vae(args, cfg, device, vae_dtype)
|
1505 |
+
vae.to_device("cpu")
|
1506 |
+
|
1507 |
+
# 4. Load DiT model
|
1508 |
+
logger.info("Loading DiT model")
|
1509 |
+
model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v)
|
1510 |
+
|
1511 |
+
# 5. Merge LoRA weights if needed
|
1512 |
+
if args.lora_weight is not None and len(args.lora_weight) > 0:
|
1513 |
+
merge_lora_weights(lora_wan, model, args, device)
|
1514 |
+
if args.save_merged_model:
|
1515 |
+
logger.info("Model merged and saved. Exiting.")
|
1516 |
+
return
|
1517 |
+
|
1518 |
+
# 6. Optimize model
|
1519 |
+
optimize_model(model, args, device, dit_dtype, dit_weight_dtype)
|
1520 |
+
|
1521 |
+
# Create shared models dict for generate function
|
1522 |
+
shared_models = {"vae": vae, "model": model, "encoded_contexts": encoded_contexts}
|
1523 |
+
|
1524 |
+
# 7. Generate for each prompt
|
1525 |
+
all_latents = []
|
1526 |
+
all_prompt_args = []
|
1527 |
+
|
1528 |
+
for i, prompt_data in enumerate(prompts_data):
|
1529 |
+
logger.info(f"Processing prompt {i+1}/{len(prompts_data)}: {prompt_data['prompt'][:50]}...")
|
1530 |
+
|
1531 |
+
# Apply overrides for this prompt
|
1532 |
+
prompt_args = apply_overrides(args, prompt_data)
|
1533 |
+
|
1534 |
+
# Generate latent
|
1535 |
+
latent = generate(prompt_args, gen_settings, shared_models)
|
1536 |
+
|
1537 |
+
# Save latent if needed
|
1538 |
+
height, width, _ = check_inputs(prompt_args)
|
1539 |
+
if prompt_args.output_type == "latent" or prompt_args.output_type == "both":
|
1540 |
+
save_latent(latent, prompt_args, height, width)
|
1541 |
+
|
1542 |
+
all_latents.append(latent)
|
1543 |
+
all_prompt_args.append(prompt_args)
|
1544 |
+
|
1545 |
+
# 8. Free DiT model
|
1546 |
+
del model
|
1547 |
+
clean_memory_on_device(device)
|
1548 |
+
synchronize_device(device)
|
1549 |
+
|
1550 |
+
# wait for 5 seconds until block swap is done
|
1551 |
+
logger.info("Waiting for 5 seconds to finish block swap")
|
1552 |
+
time.sleep(5)
|
1553 |
+
|
1554 |
+
gc.collect()
|
1555 |
+
clean_memory_on_device(device)
|
1556 |
+
|
1557 |
+
# 9. Decode latents if needed
|
1558 |
+
if args.output_type != "latent":
|
1559 |
+
logger.info("Decoding latents to videos/images")
|
1560 |
+
|
1561 |
+
if vae is None:
|
1562 |
+
vae = load_vae(args, cfg, device, vae_dtype)
|
1563 |
+
|
1564 |
+
vae.to_device(device)
|
1565 |
+
|
1566 |
+
for i, (latent, prompt_args) in enumerate(zip(all_latents, all_prompt_args)):
|
1567 |
+
logger.info(f"Decoding output {i+1}/{len(all_latents)}")
|
1568 |
+
|
1569 |
+
# Decode latent
|
1570 |
+
video = decode_latent(latent.unsqueeze(0), prompt_args, cfg)
|
1571 |
+
|
1572 |
+
# Save as video or images
|
1573 |
+
if prompt_args.output_type == "video" or prompt_args.output_type == "both":
|
1574 |
+
save_video(video, prompt_args)
|
1575 |
+
elif prompt_args.output_type == "images":
|
1576 |
+
save_images(video, prompt_args)
|
1577 |
+
|
1578 |
+
# Free VAE
|
1579 |
+
del vae
|
1580 |
+
|
1581 |
+
clean_memory_on_device(device)
|
1582 |
+
gc.collect()
|
1583 |
+
|
1584 |
+
|
1585 |
+
def process_interactive(args: argparse.Namespace) -> None:
|
1586 |
+
"""Process prompts in interactive mode
|
1587 |
+
|
1588 |
+
Args:
|
1589 |
+
args: Base command line arguments
|
1590 |
+
"""
|
1591 |
+
gen_settings = get_generation_settings(args)
|
1592 |
+
device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = (
|
1593 |
+
gen_settings.device,
|
1594 |
+
gen_settings.cfg,
|
1595 |
+
gen_settings.dit_dtype,
|
1596 |
+
gen_settings.dit_weight_dtype,
|
1597 |
+
gen_settings.vae_dtype,
|
1598 |
+
)
|
1599 |
+
is_i2v = "i2v" in args.task
|
1600 |
+
|
1601 |
+
# Initialize models to None
|
1602 |
+
text_encoder = None
|
1603 |
+
vae = None
|
1604 |
+
model = None
|
1605 |
+
clip = None
|
1606 |
+
|
1607 |
+
print("Interactive mode. Enter prompts (Ctrl+D to exit):")
|
1608 |
+
|
1609 |
+
try:
|
1610 |
+
while True:
|
1611 |
+
try:
|
1612 |
+
line = input("> ")
|
1613 |
+
if not line.strip():
|
1614 |
+
continue
|
1615 |
+
|
1616 |
+
# Parse prompt
|
1617 |
+
prompt_data = parse_prompt_line(line)
|
1618 |
+
prompt_args = apply_overrides(args, prompt_data)
|
1619 |
+
|
1620 |
+
# Ensure we have all the models we need
|
1621 |
+
|
1622 |
+
# 1. Load text encoder if not already loaded
|
1623 |
+
if text_encoder is None:
|
1624 |
+
logger.info("Loading text encoder")
|
1625 |
+
text_encoder = load_text_encoder(args, cfg, device)
|
1626 |
+
|
1627 |
+
text_encoder.model.to(device)
|
1628 |
+
|
1629 |
+
# Encode prompt
|
1630 |
+
n_prompt = prompt_data.get(
|
1631 |
+
"negative_prompt", prompt_args.negative_prompt if prompt_args.negative_prompt else cfg.sample_neg_prompt
|
1632 |
+
)
|
1633 |
+
|
1634 |
+
with torch.no_grad():
|
1635 |
+
if args.fp8_t5:
|
1636 |
+
with torch.amp.autocast(device_type=device.type, dtype=cfg.t5_dtype):
|
1637 |
+
context = text_encoder([prompt_data["prompt"]], device)
|
1638 |
+
context_null = text_encoder([n_prompt], device)
|
1639 |
+
else:
|
1640 |
+
context = text_encoder([prompt_data["prompt"]], device)
|
1641 |
+
context_null = text_encoder([n_prompt], device)
|
1642 |
+
|
1643 |
+
encoded_context = {"context": context, "context_null": context_null}
|
1644 |
+
|
1645 |
+
# Move text encoder to CPU after use
|
1646 |
+
text_encoder.model.to("cpu")
|
1647 |
+
|
1648 |
+
# 2. For I2V, we need CLIP and VAE
|
1649 |
+
if is_i2v:
|
1650 |
+
if clip is None:
|
1651 |
+
logger.info("Loading CLIP model")
|
1652 |
+
clip = load_clip_model(args, cfg, device)
|
1653 |
+
|
1654 |
+
clip.model.to(device)
|
1655 |
+
|
1656 |
+
# Encode image with CLIP if there's an image path
|
1657 |
+
if prompt_args.image_path and os.path.exists(prompt_args.image_path):
|
1658 |
+
img = Image.open(prompt_args.image_path).convert("RGB")
|
1659 |
+
img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)
|
1660 |
+
|
1661 |
+
with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
|
1662 |
+
clip_context = clip.visual([img_tensor[:, None, :, :]])
|
1663 |
+
|
1664 |
+
encoded_context["clip_context"] = clip_context
|
1665 |
+
|
1666 |
+
# Move CLIP to CPU after use
|
1667 |
+
clip.model.to("cpu")
|
1668 |
+
|
1669 |
+
# Load VAE if needed
|
1670 |
+
if vae is None:
|
1671 |
+
logger.info("Loading VAE model")
|
1672 |
+
vae = load_vae(args, cfg, device, vae_dtype)
|
1673 |
+
elif cfg.is_fun_control and vae is None:
|
1674 |
+
# For Fun-Control, we need VAE
|
1675 |
+
logger.info("Loading VAE model for Fun-Control")
|
1676 |
+
vae = load_vae(args, cfg, device, vae_dtype)
|
1677 |
+
|
1678 |
+
# 3. Load DiT model if not already loaded
|
1679 |
+
if model is None:
|
1680 |
+
logger.info("Loading DiT model")
|
1681 |
+
model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v)
|
1682 |
+
|
1683 |
+
# Merge LoRA weights if needed
|
1684 |
+
if args.lora_weight is not None and len(args.lora_weight) > 0:
|
1685 |
+
merge_lora_weights(lora_wan, model, args, device)
|
1686 |
+
|
1687 |
+
# Optimize model
|
1688 |
+
optimize_model(model, args, device, dit_dtype, dit_weight_dtype)
|
1689 |
+
else:
|
1690 |
+
# Move model to GPU if it was offloaded
|
1691 |
+
model.to(device)
|
1692 |
+
|
1693 |
+
# Create shared models dict
|
1694 |
+
shared_models = {"vae": vae, "model": model, "encoded_contexts": {prompt_data["prompt"]: encoded_context}}
|
1695 |
+
|
1696 |
+
# Generate latent
|
1697 |
+
latent = generate(prompt_args, gen_settings, shared_models)
|
1698 |
+
|
1699 |
+
# Move model to CPU after generation
|
1700 |
+
model.to("cpu")
|
1701 |
+
|
1702 |
+
# Save latent if needed
|
1703 |
+
height, width, _ = check_inputs(prompt_args)
|
1704 |
+
if prompt_args.output_type == "latent" or prompt_args.output_type == "both":
|
1705 |
+
save_latent(latent, prompt_args, height, width)
|
1706 |
+
|
1707 |
+
# Decode and save output
|
1708 |
+
if prompt_args.output_type != "latent":
|
1709 |
+
if vae is None:
|
1710 |
+
vae = load_vae(args, cfg, device, vae_dtype)
|
1711 |
+
|
1712 |
+
vae.to_device(device)
|
1713 |
+
video = decode_latent(latent.unsqueeze(0), prompt_args, cfg)
|
1714 |
+
|
1715 |
+
if prompt_args.output_type == "video" or prompt_args.output_type == "both":
|
1716 |
+
save_video(video, prompt_args)
|
1717 |
+
elif prompt_args.output_type == "images":
|
1718 |
+
save_images(video, prompt_args)
|
1719 |
+
|
1720 |
+
# Move VAE to CPU after use
|
1721 |
+
vae.to_device("cpu")
|
1722 |
+
|
1723 |
+
clean_memory_on_device(device)
|
1724 |
+
|
1725 |
+
except KeyboardInterrupt:
|
1726 |
+
print("\nInterrupted. Continue (Ctrl+D or Ctrl+Z (Windows) to exit)")
|
1727 |
+
continue
|
1728 |
+
|
1729 |
+
except EOFError:
|
1730 |
+
print("\nExiting interactive mode")
|
1731 |
+
|
1732 |
+
# Clean up all models
|
1733 |
+
if text_encoder is not None:
|
1734 |
+
del text_encoder
|
1735 |
+
if clip is not None:
|
1736 |
+
del clip
|
1737 |
+
if vae is not None:
|
1738 |
+
del vae
|
1739 |
+
if model is not None:
|
1740 |
+
del model
|
1741 |
+
|
1742 |
+
clean_memory_on_device(device)
|
1743 |
+
gc.collect()
|
1744 |
+
|
1745 |
+
|
1746 |
+
def get_generation_settings(args: argparse.Namespace) -> GenerationSettings:
|
1747 |
+
device = torch.device(args.device)
|
1748 |
+
|
1749 |
+
cfg = WAN_CONFIGS[args.task]
|
1750 |
+
|
1751 |
+
# select dtype
|
1752 |
+
dit_dtype = detect_wan_sd_dtype(args.dit) if args.dit is not None else torch.bfloat16
|
1753 |
+
if dit_dtype.itemsize == 1:
|
1754 |
+
# if weight is in fp8, use bfloat16 for DiT (input/output)
|
1755 |
+
dit_dtype = torch.bfloat16
|
1756 |
+
if args.fp8_scaled:
|
1757 |
+
raise ValueError(
|
1758 |
+
"DiT weights is already in fp8 format, cannot scale to fp8. Please use fp16/bf16 weights / DiTの重みはすでにfp8形式です。fp8にスケーリングできません。fp16/bf16の重みを使用してください"
|
1759 |
+
)
|
1760 |
+
|
1761 |
+
dit_weight_dtype = dit_dtype # default
|
1762 |
+
if args.fp8_scaled:
|
1763 |
+
dit_weight_dtype = None # various precision weights, so don't cast to specific dtype
|
1764 |
+
elif args.fp8:
|
1765 |
+
dit_weight_dtype = torch.float8_e4m3fn
|
1766 |
+
|
1767 |
+
vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else dit_dtype
|
1768 |
+
logger.info(
|
1769 |
+
f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}, VAE precision: {vae_dtype}"
|
1770 |
+
)
|
1771 |
+
|
1772 |
+
gen_settings = GenerationSettings(
|
1773 |
+
device=device,
|
1774 |
+
cfg=cfg,
|
1775 |
+
dit_dtype=dit_dtype,
|
1776 |
+
dit_weight_dtype=dit_weight_dtype,
|
1777 |
+
vae_dtype=vae_dtype,
|
1778 |
+
)
|
1779 |
+
return gen_settings
|
1780 |
+
|
1781 |
+
|
1782 |
+
def main():
|
1783 |
+
# Parse arguments
|
1784 |
+
args = parse_args()
|
1785 |
+
|
1786 |
+
# Check if latents are provided
|
1787 |
+
latents_mode = args.latent_path is not None and len(args.latent_path) > 0
|
1788 |
+
|
1789 |
+
# Set device
|
1790 |
+
device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
|
1791 |
+
device = torch.device(device)
|
1792 |
+
logger.info(f"Using device: {device}")
|
1793 |
+
args.device = device
|
1794 |
+
|
1795 |
+
if latents_mode:
|
1796 |
+
# Original latent decode mode
|
1797 |
+
cfg = WAN_CONFIGS[args.task] # any task is fine
|
1798 |
+
original_base_names = []
|
1799 |
+
latents_list = []
|
1800 |
+
seeds = []
|
1801 |
+
|
1802 |
+
assert len(args.latent_path) == 1, "Only one latent path is supported for now"
|
1803 |
+
|
1804 |
+
for latent_path in args.latent_path:
|
1805 |
+
original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
|
1806 |
+
seed = 0
|
1807 |
+
|
1808 |
+
if os.path.splitext(latent_path)[1] != ".safetensors":
|
1809 |
+
latents = torch.load(latent_path, map_location="cpu")
|
1810 |
+
else:
|
1811 |
+
latents = load_file(latent_path)["latent"]
|
1812 |
+
with safe_open(latent_path, framework="pt") as f:
|
1813 |
+
metadata = f.metadata()
|
1814 |
+
if metadata is None:
|
1815 |
+
metadata = {}
|
1816 |
+
logger.info(f"Loaded metadata: {metadata}")
|
1817 |
+
|
1818 |
+
if "seeds" in metadata:
|
1819 |
+
seed = int(metadata["seeds"])
|
1820 |
+
if "height" in metadata and "width" in metadata:
|
1821 |
+
height = int(metadata["height"])
|
1822 |
+
width = int(metadata["width"])
|
1823 |
+
args.video_size = [height, width]
|
1824 |
+
if "video_length" in metadata:
|
1825 |
+
args.video_length = int(metadata["video_length"])
|
1826 |
+
|
1827 |
+
seeds.append(seed)
|
1828 |
+
latents_list.append(latents)
|
1829 |
+
|
1830 |
+
logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")
|
1831 |
+
|
1832 |
+
latent = torch.stack(latents_list, dim=0) # [N, ...], must be same shape
|
1833 |
+
|
1834 |
+
height = latents.shape[-2]
|
1835 |
+
width = latents.shape[-1]
|
1836 |
+
height *= cfg.patch_size[1] * cfg.vae_stride[1]
|
1837 |
+
width *= cfg.patch_size[2] * cfg.vae_stride[2]
|
1838 |
+
video_length = latents.shape[1]
|
1839 |
+
video_length = (video_length - 1) * cfg.vae_stride[0] + 1
|
1840 |
+
args.seed = seeds[0]
|
1841 |
+
|
1842 |
+
# Decode and save
|
1843 |
+
save_output(latent[0], args, cfg, height, width, original_base_names)
|
1844 |
+
|
1845 |
+
elif args.from_file:
|
1846 |
+
# Batch mode from file
|
1847 |
+
args = setup_args(args)
|
1848 |
+
|
1849 |
+
# Read prompts from file
|
1850 |
+
with open(args.from_file, "r", encoding="utf-8") as f:
|
1851 |
+
prompt_lines = f.readlines()
|
1852 |
+
|
1853 |
+
# Process prompts
|
1854 |
+
prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
|
1855 |
+
process_batch_prompts(prompts_data, args)
|
1856 |
+
|
1857 |
+
elif args.interactive:
|
1858 |
+
# Interactive mode
|
1859 |
+
args = setup_args(args)
|
1860 |
+
process_interactive(args)
|
1861 |
+
|
1862 |
+
else:
|
1863 |
+
# Single prompt mode (original behavior)
|
1864 |
+
args = setup_args(args)
|
1865 |
+
height, width, video_length = check_inputs(args)
|
1866 |
+
|
1867 |
+
logger.info(
|
1868 |
+
f"Video size: {height}x{width}@{video_length} (HxW@F), fps: {args.fps}, "
|
1869 |
+
f"infer_steps: {args.infer_steps}, flow_shift: {args.flow_shift}"
|
1870 |
+
)
|
1871 |
+
|
1872 |
+
# Generate latent
|
1873 |
+
gen_settings = get_generation_settings(args)
|
1874 |
+
latent = generate(args, gen_settings)
|
1875 |
+
|
1876 |
+
# Make sure the model is freed from GPU memory
|
1877 |
+
gc.collect()
|
1878 |
+
clean_memory_on_device(args.device)
|
1879 |
+
|
1880 |
+
# Save latent and video
|
1881 |
+
if args.save_merged_model:
|
1882 |
+
return
|
1883 |
+
|
1884 |
+
# Add batch dimension
|
1885 |
+
latent = latent.unsqueeze(0)
|
1886 |
+
save_output(latent[0], args, WAN_CONFIGS[args.task], height, width)
|
1887 |
+
|
1888 |
+
logger.info("Done!")
|
1889 |
+
|
1890 |
+
|
1891 |
+
if __name__ == "__main__":
|
1892 |
+
main()
|
blissful_tuner/GIMMVFI.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Frame Rate Interpolation using GIMM-VFI
|
5 |
+
-----------------------------------
|
6 |
+
This specific code file as well as all files in ./blissful_tuner/gimmvfi and subfolders (all GIMM-VFI related code) licensed:
|
7 |
+
|
8 |
+
S-Lab License 1.0
|
9 |
+
Copyright 2024 S-Lab
|
10 |
+
|
11 |
+
Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
12 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
13 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
14 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
15 |
+
|
16 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
|
17 |
+
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
18 |
+
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
19 |
+
|
20 |
+
In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
|
21 |
+
---------------------------------------
|
22 |
+
Created on Mon Apr 14 12:23:15 2025
|
23 |
+
@author: blyss
|
24 |
+
"""
|
25 |
+
|
26 |
+
import os
|
27 |
+
import warnings
|
28 |
+
from typing import List
|
29 |
+
import torch
|
30 |
+
import yaml
|
31 |
+
from tqdm import tqdm
|
32 |
+
from omegaconf import OmegaConf
|
33 |
+
from rich.traceback import install as install_rich_tracebacks
|
34 |
+
|
35 |
+
# Importing necessary modules from our project
|
36 |
+
from gimmvfi.generalizable_INR.gimmvfi_r import GIMMVFI_R
|
37 |
+
from gimmvfi.generalizable_INR.gimmvfi_f import GIMMVFI_F
|
38 |
+
from gimmvfi.generalizable_INR.configs import GIMMVFIConfig
|
39 |
+
from gimmvfi.generalizable_INR.raft import RAFT
|
40 |
+
from gimmvfi.generalizable_INR.flowformer.core.FlowFormer.LatentCostFormer.transformer import FlowFormer
|
41 |
+
from gimmvfi.generalizable_INR.flowformer.configs.submission import get_cfg
|
42 |
+
from gimmvfi.utils.utils import InputPadder, RaftArgs, easydict_to_dict
|
43 |
+
from utils import load_torch_file, setup_compute_context
|
44 |
+
from video_processing_common import BlissfulVideoProcessor, setup_parser_video_common, set_seed
|
45 |
+
warnings.filterwarnings("ignore")
|
46 |
+
install_rich_tracebacks()
|
47 |
+
|
48 |
+
|
49 |
+
def load_model(model_path: str, device: torch.device, dtype: torch.dtype, mode: str = "gimmvfi_r") -> torch.nn.Module:
|
50 |
+
"""
|
51 |
+
Loads the GIMM-VFI model along with its required flow estimator.
|
52 |
+
|
53 |
+
Depending on the mode ("gimmvfi_r" or "gimmvfi_f") a different configuration,
|
54 |
+
checkpoint, and flow estimation network are loaded.
|
55 |
+
"""
|
56 |
+
|
57 |
+
# Select proper configuration, checkpoint, and flow model based on mode.
|
58 |
+
if "gimmvfi_r" in mode:
|
59 |
+
config_path = os.path.join(model_path, "gimmvfi_r_arb.yaml")
|
60 |
+
flow_model_filename = "raft-things_fp32.safetensors"
|
61 |
+
checkpoint = os.path.join(model_path, "gimmvfi_r_arb_lpips_fp32.safetensors")
|
62 |
+
elif "gimmvfi_f" in mode:
|
63 |
+
config_path = os.path.join(model_path, "gimmvfi_f_arb.yaml")
|
64 |
+
checkpoint = os.path.join(model_path, "gimmvfi_f_arb_lpips_fp32.safetensors")
|
65 |
+
flow_model_filename = "flowformer_sintel_fp32.safetensors"
|
66 |
+
else:
|
67 |
+
raise ValueError(f"Unsupported mode: {mode}")
|
68 |
+
|
69 |
+
flow_model_path = os.path.join(model_path, flow_model_filename)
|
70 |
+
|
71 |
+
# Load and merge YAML configuration
|
72 |
+
with open(config_path) as f:
|
73 |
+
config = yaml.load(f, Loader=yaml.FullLoader)
|
74 |
+
config = easydict_to_dict(config)
|
75 |
+
config = OmegaConf.create(config)
|
76 |
+
arch_defaults = GIMMVFIConfig.create(config.arch)
|
77 |
+
config = OmegaConf.merge(arch_defaults, config.arch)
|
78 |
+
|
79 |
+
# Initialize the model and its associated flow estimator
|
80 |
+
if "gimmvfi_r" in mode:
|
81 |
+
model = GIMMVFI_R(config)
|
82 |
+
# Setup RAFT as flow estimator
|
83 |
+
raft_args = RaftArgs(small=False, mixed_precision=False, alternate_corr=False)
|
84 |
+
raft_model = RAFT(raft_args)
|
85 |
+
raft_sd = load_torch_file(flow_model_path)
|
86 |
+
raft_model.load_state_dict(raft_sd, strict=True)
|
87 |
+
flow_estimator = raft_model.to(device, dtype)
|
88 |
+
else: # mode == "gimmvfi_f"
|
89 |
+
model = GIMMVFI_F(config)
|
90 |
+
cfg = get_cfg()
|
91 |
+
flowformer = FlowFormer(cfg.latentcostformer)
|
92 |
+
flowformer_sd = load_torch_file(flow_model_path)
|
93 |
+
flowformer.load_state_dict(flowformer_sd, strict=True)
|
94 |
+
flow_estimator = flowformer.to(device, dtype)
|
95 |
+
|
96 |
+
# Load main model checkpoint
|
97 |
+
sd = load_torch_file(checkpoint)
|
98 |
+
model.load_state_dict(sd, strict=False)
|
99 |
+
|
100 |
+
# Attach the flow estimator to the model, set evaluation mode, and move to device
|
101 |
+
model.flow_estimator = flow_estimator
|
102 |
+
model = model.eval().to(device, dtype)
|
103 |
+
|
104 |
+
return model
|
105 |
+
|
106 |
+
|
107 |
+
def interpolate(model: torch.nn.Module, frames: List[torch.Tensor], ds_factor: float, N: int, VideoProcessor: BlissfulVideoProcessor):
|
108 |
+
"""
|
109 |
+
Interpolates frames using the provided model.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
model: The loaded interpolation model.
|
113 |
+
frames: List of input frame tensors.
|
114 |
+
ds_factor: Downsampling factor used by the model.
|
115 |
+
N: Number of interpolation steps between two frames.
|
116 |
+
"""
|
117 |
+
device = VideoProcessor.device
|
118 |
+
dtype = VideoProcessor.dtype
|
119 |
+
start = 0
|
120 |
+
end = len(frames) - 1
|
121 |
+
|
122 |
+
# Process each adjacent pair of frames.
|
123 |
+
for j in tqdm(range(start, end), desc="Interpolating frames"):
|
124 |
+
I0 = frames[j]
|
125 |
+
I2 = frames[j + 1]
|
126 |
+
|
127 |
+
# For the very first frame, add it directly.
|
128 |
+
if j == start:
|
129 |
+
VideoProcessor.write_np_or_tensor_to_png(I0)
|
130 |
+
|
131 |
+
# Pad both images so that their dimensions are divisible by 32.
|
132 |
+
padder = InputPadder(I0.shape, 32)
|
133 |
+
I0_padded, I2_padded = padder.pad(I0, I2)
|
134 |
+
# Concatenate along a new dimension to create a tensor of shape [batch, 2, C, H, W]
|
135 |
+
xs = torch.cat((I0_padded.unsqueeze(2), I2_padded.unsqueeze(2)), dim=2).to(device, dtype, non_blocking=True)
|
136 |
+
|
137 |
+
model.zero_grad()
|
138 |
+
|
139 |
+
batch_size = xs.shape[0]
|
140 |
+
s_shape = xs.shape[-2:]
|
141 |
+
|
142 |
+
with torch.no_grad():
|
143 |
+
# Prepare coordinate inputs and timesteps for interpolation.
|
144 |
+
coord_inputs = [
|
145 |
+
(
|
146 |
+
model.sample_coord_input(
|
147 |
+
batch_size,
|
148 |
+
s_shape,
|
149 |
+
[1 / N * i],
|
150 |
+
device=xs.device,
|
151 |
+
upsample_ratio=ds_factor,
|
152 |
+
),
|
153 |
+
None,
|
154 |
+
)
|
155 |
+
for i in range(1, N)
|
156 |
+
]
|
157 |
+
timesteps = [
|
158 |
+
i / N * torch.ones(batch_size, device=xs.device, dtype=dtype)
|
159 |
+
for i in range(1, N)
|
160 |
+
]
|
161 |
+
if dtype != torch.float32:
|
162 |
+
with torch.autocast(device_type=str(device), dtype=dtype):
|
163 |
+
all_outputs = model(xs, coord_inputs, t=timesteps, ds_factor=ds_factor)
|
164 |
+
else:
|
165 |
+
all_outputs = model(xs, coord_inputs, t=timesteps, ds_factor=ds_factor)
|
166 |
+
# Unpad the outputs to get back to original image size.
|
167 |
+
out_frames = [padder.unpad(im) for im in all_outputs["imgt_pred"]]
|
168 |
+
|
169 |
+
# Convert each interpolated frame tensor to an image array.
|
170 |
+
I1_pred_images = [I1_pred[0] for I1_pred in out_frames]
|
171 |
+
|
172 |
+
# Append the interpolated frames and corresponding flow images.
|
173 |
+
for i in range(N - 1):
|
174 |
+
VideoProcessor.write_np_or_tensor_to_png(I1_pred_images[i])
|
175 |
+
|
176 |
+
# Append the next original frame.
|
177 |
+
VideoProcessor.write_np_or_tensor_to_png(I2)
|
178 |
+
|
179 |
+
|
180 |
+
def main():
|
181 |
+
parser = setup_parser_video_common(description="Frame rate interpolation using GIMM-VFI")
|
182 |
+
parser.add_argument("--ds_factor", type=float, default=1.0, help="Downsampling factor")
|
183 |
+
parser.add_argument("--mode", type=str, default="gimmvfi_f", help="Model mode: 'gimmvfi_r' or 'gimmvfi_f' for RAFT or FlowFormer version respectively")
|
184 |
+
parser.add_argument(
|
185 |
+
"--factor", type=int, default=2, help="Factor to increase the number of frames by. \
|
186 |
+
A factor of 2 will double the fps, taking e.g. a 16fps video to 32fps. Can go up to 8 but higher values have more artifacts"
|
187 |
+
)
|
188 |
+
args = parser.parse_args()
|
189 |
+
device, dtype = setup_compute_context(None, args.dtype)
|
190 |
+
VideoProcessor = BlissfulVideoProcessor(device, dtype)
|
191 |
+
VideoProcessor.prepare_files_and_path(args.input, args.output, "VFI", args.codec, args.container)
|
192 |
+
model = load_model(args.model, device, dtype, args.mode)
|
193 |
+
frames, fps, _, _ = VideoProcessor.load_frames(make_rgb=True)
|
194 |
+
frames = VideoProcessor.np_image_to_tensor(frames)
|
195 |
+
new_fps = fps * args.factor # Adjust the frame rate according to the interpolation
|
196 |
+
|
197 |
+
# Set seed for reproducibility.
|
198 |
+
set_seed(args.seed)
|
199 |
+
|
200 |
+
# Perform the frame interpolation.
|
201 |
+
interpolate(model, frames, args.ds_factor, args.factor, VideoProcessor)
|
202 |
+
|
203 |
+
# Save the interpolated video.
|
204 |
+
VideoProcessor.write_buffered_frames_to_output(new_fps, args.keep_pngs)
|
205 |
+
|
206 |
+
|
207 |
+
if __name__ == "__main__":
|
208 |
+
main()
|
blissful_tuner/__init__.py
ADDED
File without changes
|
blissful_tuner/advanced_rope.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Created on Wed Apr 16 19:25:53 2025
|
5 |
+
Advanced rope functions for Blissful Tuner extension
|
6 |
+
License: Apache 2.0
|
7 |
+
|
8 |
+
@author: blyss
|
9 |
+
"""
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
from einops import rearrange
|
13 |
+
from typing import List
|
14 |
+
from blissful_tuner.hvw_posemb_layers import get_nd_rotary_pos_embed
|
15 |
+
|
16 |
+
|
17 |
+
# From ComfyUI
|
18 |
+
def apply_rope_comfy(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
19 |
+
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
20 |
+
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
21 |
+
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
22 |
+
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
23 |
+
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
24 |
+
|
25 |
+
|
26 |
+
# From WanVideoWrapper
|
27 |
+
def rope_riflex(pos, dim, theta, L_test, k, temporal):
|
28 |
+
assert dim % 2 == 0
|
29 |
+
device = pos.device
|
30 |
+
scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64, device=device)
|
31 |
+
omega = 1.0 / (theta**scale)
|
32 |
+
# RIFLEX modification - adjust last frequency component if L_test and k are provided
|
33 |
+
if temporal and k > 0 and L_test:
|
34 |
+
omega[k - 1] = 0.9 * 2 * torch.pi / L_test
|
35 |
+
out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
|
36 |
+
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
37 |
+
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
38 |
+
return out.to(dtype=torch.float32, device=pos.device)
|
39 |
+
|
40 |
+
|
41 |
+
class EmbedND_RifleX(nn.Module):
|
42 |
+
def __init__(self: nn.Module, dim: int, theta: float, axes_dim: List[int], num_frames: int, k: int):
|
43 |
+
super().__init__()
|
44 |
+
self.dim = dim
|
45 |
+
self.theta = theta
|
46 |
+
self.axes_dim = axes_dim
|
47 |
+
self.num_frames = num_frames
|
48 |
+
self.k = k
|
49 |
+
|
50 |
+
def forward(self, ids):
|
51 |
+
n_axes = ids.shape[-1]
|
52 |
+
emb = torch.cat(
|
53 |
+
[rope_riflex(ids[..., i], self.axes_dim[i], self.theta, self.num_frames, self.k, temporal=True if i == 0 else False) for i in range(n_axes)],
|
54 |
+
dim=-3,
|
55 |
+
)
|
56 |
+
return emb.unsqueeze(1)
|
57 |
+
|
58 |
+
|
59 |
+
# Modified from HunyuanVideo Wrapper
|
60 |
+
def get_rotary_pos_embed_riflex(vae_ver, transformer, latent_video_length, height, width, k=0):
|
61 |
+
if "884" in vae_ver:
|
62 |
+
latents_size = [(latent_video_length - 1) // 4 + 1, height // 8, width // 8]
|
63 |
+
elif "888" in vae_ver:
|
64 |
+
latents_size = [(latent_video_length - 1) // 8 + 1, height // 8, width // 8]
|
65 |
+
else:
|
66 |
+
latents_size = [latent_video_length, height // 8, width // 8]
|
67 |
+
|
68 |
+
target_ndim = 3
|
69 |
+
ndim = 5 - 2
|
70 |
+
rope_theta = 256 # 225
|
71 |
+
patch_size = transformer.patch_size
|
72 |
+
rope_dim_list = transformer.rope_dim_list
|
73 |
+
hidden_size = transformer.hidden_size
|
74 |
+
heads_num = transformer.heads_num
|
75 |
+
head_dim = hidden_size // heads_num
|
76 |
+
|
77 |
+
if isinstance(patch_size, int):
|
78 |
+
assert all(s % patch_size == 0 for s in latents_size), (
|
79 |
+
f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
|
80 |
+
f"but got {latents_size}."
|
81 |
+
)
|
82 |
+
rope_sizes = [s // patch_size for s in latents_size]
|
83 |
+
elif isinstance(patch_size, list):
|
84 |
+
assert all(
|
85 |
+
s % patch_size[idx] == 0
|
86 |
+
for idx, s in enumerate(latents_size)
|
87 |
+
), (
|
88 |
+
f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
|
89 |
+
f"but got {latents_size}."
|
90 |
+
)
|
91 |
+
rope_sizes = [
|
92 |
+
s // patch_size[idx] for idx, s in enumerate(latents_size)
|
93 |
+
]
|
94 |
+
|
95 |
+
if len(rope_sizes) != target_ndim:
|
96 |
+
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
|
97 |
+
|
98 |
+
if rope_dim_list is None:
|
99 |
+
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
100 |
+
assert (
|
101 |
+
sum(rope_dim_list) == head_dim
|
102 |
+
), "sum(rope_dim_list) should equal to head_dim of attention layer"
|
103 |
+
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
104 |
+
rope_dim_list,
|
105 |
+
rope_sizes,
|
106 |
+
theta=rope_theta,
|
107 |
+
use_real=True,
|
108 |
+
theta_rescale_factor=1,
|
109 |
+
num_frames=latent_video_length,
|
110 |
+
k=k,
|
111 |
+
)
|
112 |
+
return freqs_cos, freqs_sin
|
blissful_tuner/blissful_args.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Created on Sat Apr 26 15:11:58 2025
|
5 |
+
|
6 |
+
@author: blyss
|
7 |
+
"""
|
8 |
+
import sys
|
9 |
+
import os
|
10 |
+
import argparse
|
11 |
+
import torch
|
12 |
+
from rich.traceback import install as install_rich_tracebacks
|
13 |
+
from blissful_tuner.utils import BlissfulLogger, string_to_seed, parse_scheduled_cfg, error_out
|
14 |
+
logger = BlissfulLogger(__name__, "#8e00ed")
|
15 |
+
|
16 |
+
BLISSFUL_VERSION = "0.4.0"
|
17 |
+
|
18 |
+
CFG_SCHEDULE_HELP = """
|
19 |
+
Comma-separated list of steps/ranges where CFG should be applied.
|
20 |
+
|
21 |
+
You can specify:
|
22 |
+
- Single steps (e.g., '5')
|
23 |
+
- Ranges (e.g., '1-10')
|
24 |
+
- Modulus patterns (e.g., 'e~2' for every 2 steps)
|
25 |
+
- Guidance scale overrides (e.g., '1-10:5.0')
|
26 |
+
|
27 |
+
Example schedule:
|
28 |
+
'e~2:6.4, 1-10, 46-50'
|
29 |
+
|
30 |
+
This would apply:
|
31 |
+
- Default CFG scale for steps 1-10 and 46-50
|
32 |
+
- 6.4 CFG scale every 2 steps outside that range
|
33 |
+
- No CFG otherwise
|
34 |
+
|
35 |
+
You can exclude steps using '!', e.g., '!32' skips step 32.
|
36 |
+
Note: The list is processed left to right, so modulus ranges should come first and exclusions at the end!
|
37 |
+
"""
|
38 |
+
|
39 |
+
ROOT_SCRIPT = os.path.basename(sys.argv[0]).lower()
|
40 |
+
if "hv_" in ROOT_SCRIPT:
|
41 |
+
DIFFUSION_MODEL = "hunyuan"
|
42 |
+
elif "wan_" in ROOT_SCRIPT:
|
43 |
+
DIFFUSION_MODEL = "wan"
|
44 |
+
elif "fpack_" in ROOT_SCRIPT:
|
45 |
+
DIFFUSION_MODEL = "framepack"
|
46 |
+
else:
|
47 |
+
raise ValueError("Unsupported root_script for Blissful Extension")
|
48 |
+
|
49 |
+
if "generate" in ROOT_SCRIPT:
|
50 |
+
MODE = "generate"
|
51 |
+
elif "train" in ROOT_SCRIPT:
|
52 |
+
MODE = "train"
|
53 |
+
else:
|
54 |
+
raise ValueError("Unsupported root script for Blissful Extension!")
|
55 |
+
|
56 |
+
|
57 |
+
def blissful_prefunc(args: argparse.Namespace):
|
58 |
+
"""Simple function to print about version, environment, and things"""
|
59 |
+
cuda_list = [f"PyTorch: {torch.__version__}"]
|
60 |
+
if torch.cuda.is_available():
|
61 |
+
allocator = torch.cuda.get_allocator_backend()
|
62 |
+
cuda = torch.cuda.get_device_properties(0)
|
63 |
+
cuda_list[0] += f", CUDA: {torch.version.cuda} CC: {cuda.major}.{cuda.minor}"
|
64 |
+
cuda_list.append(f"Device: '{cuda.name}', VRAM: '{cuda.total_memory // 1024 ** 2}MB'")
|
65 |
+
for string in cuda_list:
|
66 |
+
logger.info(string)
|
67 |
+
if args.fp16_accumulation and MODE == "generate":
|
68 |
+
logger.info("Enabling FP16 accumulation")
|
69 |
+
if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
|
70 |
+
torch.backends.cuda.matmul.allow_fp16_accumulation = True
|
71 |
+
else:
|
72 |
+
raise ValueError("torch.backends.cuda.matmul.allow_fp16_accumulation is not available in this version of torch, requires torch 2.7.0.dev2025 02 26 nightly minimum")
|
73 |
+
|
74 |
+
|
75 |
+
def add_blissful_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
76 |
+
install_rich_tracebacks()
|
77 |
+
if DIFFUSION_MODEL == "wan":
|
78 |
+
parser.add_argument("--noise_aug_strength", type=float, default=0.0, help="Additional multiplier for i2v noise, higher might help motion/quality")
|
79 |
+
parser.add_argument("--prompt_weighting", action="store_true", help="Enable (prompt weighting:1.2)")
|
80 |
+
parser.add_argument(
|
81 |
+
"--rope_func", type=str, default="default",
|
82 |
+
help="Function to use for ROPE. Choose from 'default' or 'comfy' the latter of which uses ComfyUI implementation and is compilable with torch.compile to enable BIG VRAM savings"
|
83 |
+
)
|
84 |
+
|
85 |
+
elif DIFFUSION_MODEL == "hunyuan":
|
86 |
+
parser.add_argument("--hidden_state_skip_layer", type=int, default=2, help="Hidden state skip layer for LLM. Default is 2. Think 'clip skip' for the LLM")
|
87 |
+
parser.add_argument("--apply_final_norm", type=bool, default=False, help="Apply final norm for LLM. Default is False. Usually makes things worse.")
|
88 |
+
parser.add_argument("--reproduce", action="store_true", help="Enable reproducible output(Same seed = same result. Default is False.")
|
89 |
+
parser.add_argument("--fp8_scaled", action="store_true", help="Scaled FP8 quantization. Better quality/accuracy with slightly more VRAM usage.")
|
90 |
+
parser.add_argument("--prompt_2", type=str, required=False, help="Optional different prompt for CLIP")
|
91 |
+
parser.add_argument("--te_multiplier", nargs=2, metavar=("llm_multiplier", "clip_multiplier"), help="Scale clip and llm influence")
|
92 |
+
elif DIFFUSION_MODEL == "framepack":
|
93 |
+
parser.add_argument("--preview_latent_every", type=int, default=None, help="Enable latent preview every N sections. If --preview_vae is not specified it will use latent2rgb")
|
94 |
+
|
95 |
+
if DIFFUSION_MODEL in ["wan", "hunyuan"]:
|
96 |
+
parser.add_argument("--riflex_index", type=int, default=0, help="Frequency for RifleX extension. 4 is good for Hunyuan, 6 is good for Wan. Only 'comfy' rope_func supports this with Wan!")
|
97 |
+
parser.add_argument("--cfgzerostar_scaling", action="store_true", help="Enables CFG-Zero* scaling - https://github.com/WeichenFan/CFG-Zero-star")
|
98 |
+
parser.add_argument("--cfgzerostar_init_steps", type=int, default=-1, help="Enables CFGZero* zeroing out the first N steps. 2 is good for Wan T2V, 1 for I2V")
|
99 |
+
parser.add_argument("--preview_latent_every", type=int, default=None, help="Enable latent preview every N steps. If --preview_vae is not specified it will use latent2rgb")
|
100 |
+
|
101 |
+
# Common
|
102 |
+
|
103 |
+
parser.add_argument("--preview_vae", type=str, help="Path to TAE vae for taehv previews")
|
104 |
+
parser.add_argument("--cfg_schedule", type=str, help=CFG_SCHEDULE_HELP)
|
105 |
+
parser.add_argument("--keep_pngs", action="store_true", help="Save frames as PNGs in addition to output video")
|
106 |
+
parser.add_argument("--codec", choices=["prores", "h264", "h265"], default=None, help="Codec to use, choose from 'prores', 'h264', or 'h265'")
|
107 |
+
parser.add_argument("--container", choices=["mkv", "mp4"], default="mkv", help="Container format to use, choose from 'mkv' or 'mp4'. Note prores can only go in MKV!")
|
108 |
+
parser.add_argument("--fp16_accumulation", action="store_true", help="Enable full FP16 Accmumulation in FP16 GEMMs, requires Pytorch 2.7.0 or higher")
|
109 |
+
return parser
|
110 |
+
|
111 |
+
|
112 |
+
def parse_blissful_args(args: argparse.Namespace) -> argparse.Namespace:
|
113 |
+
if args.seed is not None:
|
114 |
+
try:
|
115 |
+
args.seed = int(args.seed)
|
116 |
+
except ValueError:
|
117 |
+
string_seed = args.seed
|
118 |
+
args.seed = string_to_seed(args.seed)
|
119 |
+
logger.info(f"Seed {args.seed} was generated from string '{string_seed}'!")
|
120 |
+
if DIFFUSION_MODEL == "wan":
|
121 |
+
if args.riflex_index != 0 and args.rope_func.lower() != "comfy":
|
122 |
+
logger.error("RIFLEx can only be used with rope_func == 'comfy'!")
|
123 |
+
raise ValueError("RIFLEx can only be used with rope_func =='comfy'!")
|
124 |
+
if DIFFUSION_MODEL in ["wan", "hunyuan"]:
|
125 |
+
if args.cfg_schedule:
|
126 |
+
args.cfg_schedule = parse_scheduled_cfg(args.cfg_schedule, args.infer_steps, args.guidance_scale)
|
127 |
+
if args.cfgzerostar_scaling or args.cfgzerostar_init_steps != -1:
|
128 |
+
if args.guidance_scale == 1 and not args.cfg_schedule:
|
129 |
+
error_out(AttributeError, "Requested CFGZero* but CFG is not enabled so it will have no effect!")
|
130 |
+
blissful_prefunc(args)
|
131 |
+
return args
|
blissful_tuner/blissful_settings.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Created on Tue Mar 11 19:08:55 2025
|
5 |
+
|
6 |
+
@author: blyss
|
7 |
+
"""
|
8 |
+
import os
|
9 |
+
import json
|
10 |
+
|
11 |
+
|
12 |
+
class SingletonMeta(type):
|
13 |
+
"""
|
14 |
+
The SingletonMeta class is useful for creating objects that persist as a single instance across the whole program. Basically a global class.
|
15 |
+
"""
|
16 |
+
_instances = {}
|
17 |
+
|
18 |
+
def __call__(cls, *Parameters, **kwParameters):
|
19 |
+
if cls not in cls._instances:
|
20 |
+
cls._instances[cls] = super(SingletonMeta, cls).__call__(*Parameters, **kwParameters)
|
21 |
+
return cls._instances[cls]
|
22 |
+
|
23 |
+
|
24 |
+
class BlissfulSettings(metaclass=SingletonMeta):
|
25 |
+
def __init__(self):
|
26 |
+
"""
|
27 |
+
Loads the settings from a 'settings.json' file, creating it with default settings if it doesn't exist.
|
28 |
+
|
29 |
+
This method attempts to read the program's settings from a JSON file. If the file does not exist,
|
30 |
+
it creates a new file with default settings. This ensures that the program can start with a known
|
31 |
+
set of configurations and modify them as needed.
|
32 |
+
|
33 |
+
This class is a SingletonMeta so even if we reinstantiate the class, this only happens the first time
|
34 |
+
"""
|
35 |
+
# These are globals that do not persist
|
36 |
+
self.generating = 0
|
37 |
+
self.last_preview_file = ""
|
38 |
+
|
39 |
+
default_settings = {
|
40 |
+
"prompt": "a cat walks on the grass, realistic style",
|
41 |
+
"resolution_x": 960,
|
42 |
+
"resolution_y": 544,
|
43 |
+
"fps": 24,
|
44 |
+
"embedded_guidance": 6.0,
|
45 |
+
"flow_shift": 7.0,
|
46 |
+
"infer_steps": 50,
|
47 |
+
"seed": 42,
|
48 |
+
"video_length": 129,
|
49 |
+
"attention": "sage",
|
50 |
+
"blocks_to_swap": 0,
|
51 |
+
"hidden_state_skip_layer": 2,
|
52 |
+
"apply_final_norm": False,
|
53 |
+
"reproduce": False,
|
54 |
+
"fp8": True,
|
55 |
+
"fp8_fast": False,
|
56 |
+
"do_compile": False,
|
57 |
+
"transformer_path": "",
|
58 |
+
"text_encoder_1_path": "",
|
59 |
+
"text_encoder_2_path": "",
|
60 |
+
"vae_path": "",
|
61 |
+
"lora_path": "",
|
62 |
+
}
|
63 |
+
|
64 |
+
if not os.path.exists("./settings.json"):
|
65 |
+
with open("./settings.json", "w", encoding="utf-8") as file:
|
66 |
+
json.dump(default_settings, file, indent=4)
|
67 |
+
print("No existing settings found. Created default settings file.")
|
68 |
+
|
69 |
+
with open("./settings.json", "r", encoding="utf-8") as file:
|
70 |
+
data = json.load(file)
|
71 |
+
|
72 |
+
for key, default_value in default_settings.items():
|
73 |
+
setattr(self, key, data.get(key, default_value))
|
74 |
+
|
75 |
+
def save_to_file(self):
|
76 |
+
"""
|
77 |
+
Saves the current settings to a JSON file named 'settings.json'.
|
78 |
+
"""
|
79 |
+
settings = {
|
80 |
+
"prompt": self.prompt,
|
81 |
+
"resolution_x": self.resolution_x,
|
82 |
+
"resolution_y": self.resolution_y,
|
83 |
+
"fps": self.fps,
|
84 |
+
"embedded_guidance": self.embedded_guidance,
|
85 |
+
"flow_shift": self.flow_shift,
|
86 |
+
"infer_steps": self.infer_steps,
|
87 |
+
"seed": self.seed,
|
88 |
+
"video_length": self.video_length,
|
89 |
+
"attention": self.attention,
|
90 |
+
"blocks_to_swap": self.blocks_to_swap,
|
91 |
+
"hidden_state_skip_layer": self.hidden_state_skip_layer,
|
92 |
+
"apply_final_norm": self.apply_final_norm,
|
93 |
+
"reproduce": self.reproduce,
|
94 |
+
"fp8": self.fp8,
|
95 |
+
"fp8_fast": self.fp8_fast,
|
96 |
+
"do_compile": self.do_compile,
|
97 |
+
"transformer_path": self.transformer_path,
|
98 |
+
"text_encoder_1_path": self.text_encoder_1_path,
|
99 |
+
"text_encoder_2_path": self.text_encoder_2_path,
|
100 |
+
"vae_path": self.vae_path,
|
101 |
+
"lora_path": self.lora_path,
|
102 |
+
}
|
103 |
+
|
104 |
+
with open("./settings.json", "w", encoding="utf-8") as file:
|
105 |
+
json.dump(settings, file, indent=4)
|
106 |
+
|
107 |
+
def update(self, option, value, label_target=None, label_value=None):
|
108 |
+
"""Method for updating various settings called via QT connection and may update an associated label/value"""
|
109 |
+
setattr(self, option, value)
|
110 |
+
if label_target is not None and label_value is not None:
|
111 |
+
label_target.setText(str(label_value))
|
blissful_tuner/cfgzerostar.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Created on Wed Apr 16 17:23:28 2025
|
5 |
+
CFGZero* implementation for Blissful Tuner extension based on https://github.com/WeichenFan/CFG-Zero-star/blob/main/models/wan/wan_pipeline.py
|
6 |
+
License: Apache 2.0
|
7 |
+
@author: blyss
|
8 |
+
"""
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
def apply_zerostar(cond: torch.Tensor, uncond: torch.Tensor, current_step: int, guidance_scale: float, use_scaling: bool = True, zero_init_steps: int = -1) -> torch.Tensor:
|
13 |
+
|
14 |
+
if (current_step <= zero_init_steps):
|
15 |
+
return cond * 0
|
16 |
+
if not use_scaling:
|
17 |
+
# CFG formula
|
18 |
+
noise_pred = uncond + guidance_scale * (cond - uncond)
|
19 |
+
else:
|
20 |
+
batch_size = cond.shape[0]
|
21 |
+
positive_flat = cond.view(batch_size, -1)
|
22 |
+
negative_flat = uncond.view(batch_size, -1)
|
23 |
+
alpha = optimized_scale(positive_flat, negative_flat)
|
24 |
+
alpha = alpha.view(batch_size, *([1] * (len(cond.shape) - 1)))
|
25 |
+
alpha = alpha.to(cond.dtype)
|
26 |
+
# CFG formula modified with alpha
|
27 |
+
noise_pred = uncond * alpha + guidance_scale * (cond - uncond * alpha)
|
28 |
+
return noise_pred
|
29 |
+
|
30 |
+
|
31 |
+
def optimized_scale(positive_flat, negative_flat):
|
32 |
+
|
33 |
+
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
34 |
+
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
35 |
+
|
36 |
+
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
|
37 |
+
st_star = dot_product / squared_norm
|
38 |
+
|
39 |
+
return st_star
|
blissful_tuner/codeformer/LICENSE
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
THIS FOLDER AND SUBFOLDERS (all CodeFormer related code and files) LICENSED AS BELOW
|
2 |
+
|
3 |
+
S-Lab License 1.0
|
4 |
+
Copyright 2024 S-Lab
|
5 |
+
|
6 |
+
Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
7 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
8 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
9 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
10 |
+
|
11 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
|
12 |
+
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
13 |
+
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
14 |
+
|
15 |
+
In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
|
blissful_tuner/codeformer/basicsr/VERSION
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
1.3.2
|
blissful_tuner/codeformer/basicsr/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/xinntao/BasicSR
|
2 |
+
# flake8: noqa
|
3 |
+
from .archs import *
|
4 |
+
from .data import *
|
5 |
+
from .losses import *
|
6 |
+
from .metrics import *
|
7 |
+
from .models import *
|
8 |
+
from .ops import *
|
9 |
+
from .train import *
|
10 |
+
from .utils import *
|
11 |
+
#from .version import __gitsha__, __version__
|
blissful_tuner/codeformer/basicsr/archs/__init__.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
from copy import deepcopy
|
3 |
+
from os import path as osp
|
4 |
+
|
5 |
+
from codeformer.basicsr.utils import get_root_logger, scandir
|
6 |
+
from codeformer.basicsr.utils.registry import ARCH_REGISTRY
|
7 |
+
|
8 |
+
__all__ = ['build_network']
|
9 |
+
|
10 |
+
# automatically scan and import arch modules for registry
|
11 |
+
# scan all the files under the 'archs' folder and collect files ending with
|
12 |
+
# '_arch.py'
|
13 |
+
arch_folder = osp.dirname(osp.abspath(__file__))
|
14 |
+
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
|
15 |
+
# import all the arch modules
|
16 |
+
_arch_modules = [importlib.import_module(f'codeformer.basicsr.archs.{file_name}') for file_name in arch_filenames]
|
17 |
+
|
18 |
+
|
19 |
+
def build_network(opt):
|
20 |
+
opt = deepcopy(opt)
|
21 |
+
network_type = opt.pop('type')
|
22 |
+
net = ARCH_REGISTRY.get(network_type)(**opt)
|
23 |
+
logger = get_root_logger()
|
24 |
+
logger.info(f'Network [{net.__class__.__name__}] is created.')
|
25 |
+
return net
|
blissful_tuner/codeformer/basicsr/archs/arcface_arch.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from codeformer.basicsr.utils.registry import ARCH_REGISTRY
|
3 |
+
|
4 |
+
|
5 |
+
def conv3x3(inplanes, outplanes, stride=1):
|
6 |
+
"""A simple wrapper for 3x3 convolution with padding.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
inplanes (int): Channel number of inputs.
|
10 |
+
outplanes (int): Channel number of outputs.
|
11 |
+
stride (int): Stride in convolution. Default: 1.
|
12 |
+
"""
|
13 |
+
return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
|
14 |
+
|
15 |
+
|
16 |
+
class BasicBlock(nn.Module):
|
17 |
+
"""Basic residual block used in the ResNetArcFace architecture.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
inplanes (int): Channel number of inputs.
|
21 |
+
planes (int): Channel number of outputs.
|
22 |
+
stride (int): Stride in convolution. Default: 1.
|
23 |
+
downsample (nn.Module): The downsample module. Default: None.
|
24 |
+
"""
|
25 |
+
expansion = 1 # output channel expansion ratio
|
26 |
+
|
27 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
28 |
+
super(BasicBlock, self).__init__()
|
29 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
30 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
31 |
+
self.relu = nn.ReLU(inplace=True)
|
32 |
+
self.conv2 = conv3x3(planes, planes)
|
33 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
34 |
+
self.downsample = downsample
|
35 |
+
self.stride = stride
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
residual = x
|
39 |
+
|
40 |
+
out = self.conv1(x)
|
41 |
+
out = self.bn1(out)
|
42 |
+
out = self.relu(out)
|
43 |
+
|
44 |
+
out = self.conv2(out)
|
45 |
+
out = self.bn2(out)
|
46 |
+
|
47 |
+
if self.downsample is not None:
|
48 |
+
residual = self.downsample(x)
|
49 |
+
|
50 |
+
out += residual
|
51 |
+
out = self.relu(out)
|
52 |
+
|
53 |
+
return out
|
54 |
+
|
55 |
+
|
56 |
+
class IRBlock(nn.Module):
|
57 |
+
"""Improved residual block (IR Block) used in the ResNetArcFace architecture.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
inplanes (int): Channel number of inputs.
|
61 |
+
planes (int): Channel number of outputs.
|
62 |
+
stride (int): Stride in convolution. Default: 1.
|
63 |
+
downsample (nn.Module): The downsample module. Default: None.
|
64 |
+
use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
|
65 |
+
"""
|
66 |
+
expansion = 1 # output channel expansion ratio
|
67 |
+
|
68 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
|
69 |
+
super(IRBlock, self).__init__()
|
70 |
+
self.bn0 = nn.BatchNorm2d(inplanes)
|
71 |
+
self.conv1 = conv3x3(inplanes, inplanes)
|
72 |
+
self.bn1 = nn.BatchNorm2d(inplanes)
|
73 |
+
self.prelu = nn.PReLU()
|
74 |
+
self.conv2 = conv3x3(inplanes, planes, stride)
|
75 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
76 |
+
self.downsample = downsample
|
77 |
+
self.stride = stride
|
78 |
+
self.use_se = use_se
|
79 |
+
if self.use_se:
|
80 |
+
self.se = SEBlock(planes)
|
81 |
+
|
82 |
+
def forward(self, x):
|
83 |
+
residual = x
|
84 |
+
out = self.bn0(x)
|
85 |
+
out = self.conv1(out)
|
86 |
+
out = self.bn1(out)
|
87 |
+
out = self.prelu(out)
|
88 |
+
|
89 |
+
out = self.conv2(out)
|
90 |
+
out = self.bn2(out)
|
91 |
+
if self.use_se:
|
92 |
+
out = self.se(out)
|
93 |
+
|
94 |
+
if self.downsample is not None:
|
95 |
+
residual = self.downsample(x)
|
96 |
+
|
97 |
+
out += residual
|
98 |
+
out = self.prelu(out)
|
99 |
+
|
100 |
+
return out
|
101 |
+
|
102 |
+
|
103 |
+
class Bottleneck(nn.Module):
|
104 |
+
"""Bottleneck block used in the ResNetArcFace architecture.
|
105 |
+
|
106 |
+
Args:
|
107 |
+
inplanes (int): Channel number of inputs.
|
108 |
+
planes (int): Channel number of outputs.
|
109 |
+
stride (int): Stride in convolution. Default: 1.
|
110 |
+
downsample (nn.Module): The downsample module. Default: None.
|
111 |
+
"""
|
112 |
+
expansion = 4 # output channel expansion ratio
|
113 |
+
|
114 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
115 |
+
super(Bottleneck, self).__init__()
|
116 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
117 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
118 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
119 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
120 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
|
121 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
122 |
+
self.relu = nn.ReLU(inplace=True)
|
123 |
+
self.downsample = downsample
|
124 |
+
self.stride = stride
|
125 |
+
|
126 |
+
def forward(self, x):
|
127 |
+
residual = x
|
128 |
+
|
129 |
+
out = self.conv1(x)
|
130 |
+
out = self.bn1(out)
|
131 |
+
out = self.relu(out)
|
132 |
+
|
133 |
+
out = self.conv2(out)
|
134 |
+
out = self.bn2(out)
|
135 |
+
out = self.relu(out)
|
136 |
+
|
137 |
+
out = self.conv3(out)
|
138 |
+
out = self.bn3(out)
|
139 |
+
|
140 |
+
if self.downsample is not None:
|
141 |
+
residual = self.downsample(x)
|
142 |
+
|
143 |
+
out += residual
|
144 |
+
out = self.relu(out)
|
145 |
+
|
146 |
+
return out
|
147 |
+
|
148 |
+
|
149 |
+
class SEBlock(nn.Module):
|
150 |
+
"""The squeeze-and-excitation block (SEBlock) used in the IRBlock.
|
151 |
+
|
152 |
+
Args:
|
153 |
+
channel (int): Channel number of inputs.
|
154 |
+
reduction (int): Channel reduction ration. Default: 16.
|
155 |
+
"""
|
156 |
+
|
157 |
+
def __init__(self, channel, reduction=16):
|
158 |
+
super(SEBlock, self).__init__()
|
159 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information
|
160 |
+
self.fc = nn.Sequential(
|
161 |
+
nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
|
162 |
+
nn.Sigmoid())
|
163 |
+
|
164 |
+
def forward(self, x):
|
165 |
+
b, c, _, _ = x.size()
|
166 |
+
y = self.avg_pool(x).view(b, c)
|
167 |
+
y = self.fc(y).view(b, c, 1, 1)
|
168 |
+
return x * y
|
169 |
+
|
170 |
+
|
171 |
+
@ARCH_REGISTRY.register()
|
172 |
+
class ResNetArcFace(nn.Module):
|
173 |
+
"""ArcFace with ResNet architectures.
|
174 |
+
|
175 |
+
Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
|
176 |
+
|
177 |
+
Args:
|
178 |
+
block (str): Block used in the ArcFace architecture.
|
179 |
+
layers (tuple(int)): Block numbers in each layer.
|
180 |
+
use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
|
181 |
+
"""
|
182 |
+
|
183 |
+
def __init__(self, block, layers, use_se=True):
|
184 |
+
if block == 'IRBlock':
|
185 |
+
block = IRBlock
|
186 |
+
self.inplanes = 64
|
187 |
+
self.use_se = use_se
|
188 |
+
super(ResNetArcFace, self).__init__()
|
189 |
+
|
190 |
+
self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
|
191 |
+
self.bn1 = nn.BatchNorm2d(64)
|
192 |
+
self.prelu = nn.PReLU()
|
193 |
+
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
|
194 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
195 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
196 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
197 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
198 |
+
self.bn4 = nn.BatchNorm2d(512)
|
199 |
+
self.dropout = nn.Dropout()
|
200 |
+
self.fc5 = nn.Linear(512 * 8 * 8, 512)
|
201 |
+
self.bn5 = nn.BatchNorm1d(512)
|
202 |
+
|
203 |
+
# initialization
|
204 |
+
for m in self.modules():
|
205 |
+
if isinstance(m, nn.Conv2d):
|
206 |
+
nn.init.xavier_normal_(m.weight)
|
207 |
+
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
|
208 |
+
nn.init.constant_(m.weight, 1)
|
209 |
+
nn.init.constant_(m.bias, 0)
|
210 |
+
elif isinstance(m, nn.Linear):
|
211 |
+
nn.init.xavier_normal_(m.weight)
|
212 |
+
nn.init.constant_(m.bias, 0)
|
213 |
+
|
214 |
+
def _make_layer(self, block, planes, num_blocks, stride=1):
|
215 |
+
downsample = None
|
216 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
217 |
+
downsample = nn.Sequential(
|
218 |
+
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
|
219 |
+
nn.BatchNorm2d(planes * block.expansion),
|
220 |
+
)
|
221 |
+
layers = []
|
222 |
+
layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
|
223 |
+
self.inplanes = planes
|
224 |
+
for _ in range(1, num_blocks):
|
225 |
+
layers.append(block(self.inplanes, planes, use_se=self.use_se))
|
226 |
+
|
227 |
+
return nn.Sequential(*layers)
|
228 |
+
|
229 |
+
def forward(self, x):
|
230 |
+
x = self.conv1(x)
|
231 |
+
x = self.bn1(x)
|
232 |
+
x = self.prelu(x)
|
233 |
+
x = self.maxpool(x)
|
234 |
+
|
235 |
+
x = self.layer1(x)
|
236 |
+
x = self.layer2(x)
|
237 |
+
x = self.layer3(x)
|
238 |
+
x = self.layer4(x)
|
239 |
+
x = self.bn4(x)
|
240 |
+
x = self.dropout(x)
|
241 |
+
x = x.view(x.size(0), -1)
|
242 |
+
x = self.fc5(x)
|
243 |
+
x = self.bn5(x)
|
244 |
+
|
245 |
+
return x
|
blissful_tuner/codeformer/basicsr/archs/arch_util.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections.abc
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torchvision
|
5 |
+
import warnings
|
6 |
+
from distutils.version import LooseVersion
|
7 |
+
from itertools import repeat
|
8 |
+
from torch import nn as nn
|
9 |
+
from torch.nn import functional as F
|
10 |
+
from torch.nn import init as init
|
11 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
12 |
+
|
13 |
+
from codeformer.basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
|
14 |
+
from codeformer.basicsr.utils import get_root_logger
|
15 |
+
|
16 |
+
|
17 |
+
@torch.no_grad()
|
18 |
+
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
|
19 |
+
"""Initialize network weights.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
module_list (list[nn.Module] | nn.Module): Modules to be initialized.
|
23 |
+
scale (float): Scale initialized weights, especially for residual
|
24 |
+
blocks. Default: 1.
|
25 |
+
bias_fill (float): The value to fill bias. Default: 0
|
26 |
+
kwargs (dict): Other arguments for initialization function.
|
27 |
+
"""
|
28 |
+
if not isinstance(module_list, list):
|
29 |
+
module_list = [module_list]
|
30 |
+
for module in module_list:
|
31 |
+
for m in module.modules():
|
32 |
+
if isinstance(m, nn.Conv2d):
|
33 |
+
init.kaiming_normal_(m.weight, **kwargs)
|
34 |
+
m.weight.data *= scale
|
35 |
+
if m.bias is not None:
|
36 |
+
m.bias.data.fill_(bias_fill)
|
37 |
+
elif isinstance(m, nn.Linear):
|
38 |
+
init.kaiming_normal_(m.weight, **kwargs)
|
39 |
+
m.weight.data *= scale
|
40 |
+
if m.bias is not None:
|
41 |
+
m.bias.data.fill_(bias_fill)
|
42 |
+
elif isinstance(m, _BatchNorm):
|
43 |
+
init.constant_(m.weight, 1)
|
44 |
+
if m.bias is not None:
|
45 |
+
m.bias.data.fill_(bias_fill)
|
46 |
+
|
47 |
+
|
48 |
+
def make_layer(basic_block, num_basic_block, **kwarg):
|
49 |
+
"""Make layers by stacking the same blocks.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
basic_block (nn.module): nn.module class for basic block.
|
53 |
+
num_basic_block (int): number of blocks.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
nn.Sequential: Stacked blocks in nn.Sequential.
|
57 |
+
"""
|
58 |
+
layers = []
|
59 |
+
for _ in range(num_basic_block):
|
60 |
+
layers.append(basic_block(**kwarg))
|
61 |
+
return nn.Sequential(*layers)
|
62 |
+
|
63 |
+
|
64 |
+
class ResidualBlockNoBN(nn.Module):
|
65 |
+
"""Residual block without BN.
|
66 |
+
|
67 |
+
It has a style of:
|
68 |
+
---Conv-ReLU-Conv-+-
|
69 |
+
|________________|
|
70 |
+
|
71 |
+
Args:
|
72 |
+
num_feat (int): Channel number of intermediate features.
|
73 |
+
Default: 64.
|
74 |
+
res_scale (float): Residual scale. Default: 1.
|
75 |
+
pytorch_init (bool): If set to True, use pytorch default init,
|
76 |
+
otherwise, use default_init_weights. Default: False.
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
|
80 |
+
super(ResidualBlockNoBN, self).__init__()
|
81 |
+
self.res_scale = res_scale
|
82 |
+
self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
83 |
+
self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
84 |
+
self.relu = nn.ReLU(inplace=True)
|
85 |
+
|
86 |
+
if not pytorch_init:
|
87 |
+
default_init_weights([self.conv1, self.conv2], 0.1)
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
identity = x
|
91 |
+
out = self.conv2(self.relu(self.conv1(x)))
|
92 |
+
return identity + out * self.res_scale
|
93 |
+
|
94 |
+
|
95 |
+
class Upsample(nn.Sequential):
|
96 |
+
"""Upsample module.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
100 |
+
num_feat (int): Channel number of intermediate features.
|
101 |
+
"""
|
102 |
+
|
103 |
+
def __init__(self, scale, num_feat):
|
104 |
+
m = []
|
105 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
106 |
+
for _ in range(int(math.log(scale, 2))):
|
107 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
108 |
+
m.append(nn.PixelShuffle(2))
|
109 |
+
elif scale == 3:
|
110 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
111 |
+
m.append(nn.PixelShuffle(3))
|
112 |
+
else:
|
113 |
+
raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
|
114 |
+
super(Upsample, self).__init__(*m)
|
115 |
+
|
116 |
+
|
117 |
+
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
|
118 |
+
"""Warp an image or feature map with optical flow.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
x (Tensor): Tensor with size (n, c, h, w).
|
122 |
+
flow (Tensor): Tensor with size (n, h, w, 2), normal value.
|
123 |
+
interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
|
124 |
+
padding_mode (str): 'zeros' or 'border' or 'reflection'.
|
125 |
+
Default: 'zeros'.
|
126 |
+
align_corners (bool): Before pytorch 1.3, the default value is
|
127 |
+
align_corners=True. After pytorch 1.3, the default value is
|
128 |
+
align_corners=False. Here, we use the True as default.
|
129 |
+
|
130 |
+
Returns:
|
131 |
+
Tensor: Warped image or feature map.
|
132 |
+
"""
|
133 |
+
assert x.size()[-2:] == flow.size()[1:3]
|
134 |
+
_, _, h, w = x.size()
|
135 |
+
# create mesh grid
|
136 |
+
grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
|
137 |
+
grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
|
138 |
+
grid.requires_grad = False
|
139 |
+
|
140 |
+
vgrid = grid + flow
|
141 |
+
# scale grid to [-1,1]
|
142 |
+
vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
|
143 |
+
vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
|
144 |
+
vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
|
145 |
+
output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
|
146 |
+
|
147 |
+
# TODO, what if align_corners=False
|
148 |
+
return output
|
149 |
+
|
150 |
+
|
151 |
+
def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
|
152 |
+
"""Resize a flow according to ratio or shape.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
flow (Tensor): Precomputed flow. shape [N, 2, H, W].
|
156 |
+
size_type (str): 'ratio' or 'shape'.
|
157 |
+
sizes (list[int | float]): the ratio for resizing or the final output
|
158 |
+
shape.
|
159 |
+
1) The order of ratio should be [ratio_h, ratio_w]. For
|
160 |
+
downsampling, the ratio should be smaller than 1.0 (i.e., ratio
|
161 |
+
< 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
|
162 |
+
ratio > 1.0).
|
163 |
+
2) The order of output_size should be [out_h, out_w].
|
164 |
+
interp_mode (str): The mode of interpolation for resizing.
|
165 |
+
Default: 'bilinear'.
|
166 |
+
align_corners (bool): Whether align corners. Default: False.
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
Tensor: Resized flow.
|
170 |
+
"""
|
171 |
+
_, _, flow_h, flow_w = flow.size()
|
172 |
+
if size_type == 'ratio':
|
173 |
+
output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
|
174 |
+
elif size_type == 'shape':
|
175 |
+
output_h, output_w = sizes[0], sizes[1]
|
176 |
+
else:
|
177 |
+
raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
|
178 |
+
|
179 |
+
input_flow = flow.clone()
|
180 |
+
ratio_h = output_h / flow_h
|
181 |
+
ratio_w = output_w / flow_w
|
182 |
+
input_flow[:, 0, :, :] *= ratio_w
|
183 |
+
input_flow[:, 1, :, :] *= ratio_h
|
184 |
+
resized_flow = F.interpolate(
|
185 |
+
input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
|
186 |
+
return resized_flow
|
187 |
+
|
188 |
+
|
189 |
+
# TODO: may write a cpp file
|
190 |
+
def pixel_unshuffle(x, scale):
|
191 |
+
""" Pixel unshuffle.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
x (Tensor): Input feature with shape (b, c, hh, hw).
|
195 |
+
scale (int): Downsample ratio.
|
196 |
+
|
197 |
+
Returns:
|
198 |
+
Tensor: the pixel unshuffled feature.
|
199 |
+
"""
|
200 |
+
b, c, hh, hw = x.size()
|
201 |
+
out_channel = c * (scale**2)
|
202 |
+
assert hh % scale == 0 and hw % scale == 0
|
203 |
+
h = hh // scale
|
204 |
+
w = hw // scale
|
205 |
+
x_view = x.view(b, c, h, scale, w, scale)
|
206 |
+
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
|
207 |
+
|
208 |
+
|
209 |
+
class DCNv2Pack(ModulatedDeformConvPack):
|
210 |
+
"""Modulated deformable conv for deformable alignment.
|
211 |
+
|
212 |
+
Different from the official DCNv2Pack, which generates offsets and masks
|
213 |
+
from the preceding features, this DCNv2Pack takes another different
|
214 |
+
features to generate offsets and masks.
|
215 |
+
|
216 |
+
Ref:
|
217 |
+
Delving Deep into Deformable Alignment in Video Super-Resolution.
|
218 |
+
"""
|
219 |
+
|
220 |
+
def forward(self, x, feat):
|
221 |
+
out = self.conv_offset(feat)
|
222 |
+
o1, o2, mask = torch.chunk(out, 3, dim=1)
|
223 |
+
offset = torch.cat((o1, o2), dim=1)
|
224 |
+
mask = torch.sigmoid(mask)
|
225 |
+
|
226 |
+
offset_absmean = torch.mean(torch.abs(offset))
|
227 |
+
if offset_absmean > 50:
|
228 |
+
logger = get_root_logger()
|
229 |
+
logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
|
230 |
+
|
231 |
+
if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
|
232 |
+
return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
|
233 |
+
self.dilation, mask)
|
234 |
+
else:
|
235 |
+
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
|
236 |
+
self.dilation, self.groups, self.deformable_groups)
|
237 |
+
|
238 |
+
|
239 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
240 |
+
# From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
|
241 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
242 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
243 |
+
def norm_cdf(x):
|
244 |
+
# Computes standard normal cumulative distribution function
|
245 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
246 |
+
|
247 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
248 |
+
warnings.warn(
|
249 |
+
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
|
250 |
+
'The distribution of values may be incorrect.',
|
251 |
+
stacklevel=2)
|
252 |
+
|
253 |
+
with torch.no_grad():
|
254 |
+
# Values are generated by using a truncated uniform distribution and
|
255 |
+
# then using the inverse CDF for the normal distribution.
|
256 |
+
# Get upper and lower cdf values
|
257 |
+
low = norm_cdf((a - mean) / std)
|
258 |
+
up = norm_cdf((b - mean) / std)
|
259 |
+
|
260 |
+
# Uniformly fill tensor with values from [low, up], then translate to
|
261 |
+
# [2l-1, 2u-1].
|
262 |
+
tensor.uniform_(2 * low - 1, 2 * up - 1)
|
263 |
+
|
264 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
265 |
+
# standard normal
|
266 |
+
tensor.erfinv_()
|
267 |
+
|
268 |
+
# Transform to proper mean, std
|
269 |
+
tensor.mul_(std * math.sqrt(2.))
|
270 |
+
tensor.add_(mean)
|
271 |
+
|
272 |
+
# Clamp to ensure it's in the proper range
|
273 |
+
tensor.clamp_(min=a, max=b)
|
274 |
+
return tensor
|
275 |
+
|
276 |
+
|
277 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
278 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
279 |
+
normal distribution.
|
280 |
+
|
281 |
+
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
|
282 |
+
|
283 |
+
The values are effectively drawn from the
|
284 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
285 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
286 |
+
the bounds. The method used for generating the random values works
|
287 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
288 |
+
|
289 |
+
Args:
|
290 |
+
tensor: an n-dimensional `torch.Tensor`
|
291 |
+
mean: the mean of the normal distribution
|
292 |
+
std: the standard deviation of the normal distribution
|
293 |
+
a: the minimum cutoff value
|
294 |
+
b: the maximum cutoff value
|
295 |
+
|
296 |
+
Examples:
|
297 |
+
>>> w = torch.empty(3, 5)
|
298 |
+
>>> nn.init.trunc_normal_(w)
|
299 |
+
"""
|
300 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
301 |
+
|
302 |
+
|
303 |
+
# From PyTorch
|
304 |
+
def _ntuple(n):
|
305 |
+
|
306 |
+
def parse(x):
|
307 |
+
if isinstance(x, collections.abc.Iterable):
|
308 |
+
return x
|
309 |
+
return tuple(repeat(x, n))
|
310 |
+
|
311 |
+
return parse
|
312 |
+
|
313 |
+
|
314 |
+
to_1tuple = _ntuple(1)
|
315 |
+
to_2tuple = _ntuple(2)
|
316 |
+
to_3tuple = _ntuple(3)
|
317 |
+
to_4tuple = _ntuple(4)
|
318 |
+
to_ntuple = _ntuple
|
blissful_tuner/codeformer/basicsr/archs/codeformer_arch.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch import nn, Tensor
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from typing import Optional, List
|
7 |
+
|
8 |
+
from codeformer.basicsr.archs.vqgan_arch import *
|
9 |
+
from codeformer.basicsr.utils import get_root_logger
|
10 |
+
from codeformer.basicsr.utils.registry import ARCH_REGISTRY
|
11 |
+
|
12 |
+
def calc_mean_std(feat, eps=1e-5):
|
13 |
+
"""Calculate mean and std for adaptive_instance_normalization.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
feat (Tensor): 4D tensor.
|
17 |
+
eps (float): A small value added to the variance to avoid
|
18 |
+
divide-by-zero. Default: 1e-5.
|
19 |
+
"""
|
20 |
+
size = feat.size()
|
21 |
+
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
22 |
+
b, c = size[:2]
|
23 |
+
feat_var = feat.view(b, c, -1).var(dim=2) + eps
|
24 |
+
feat_std = feat_var.sqrt().view(b, c, 1, 1)
|
25 |
+
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
|
26 |
+
return feat_mean, feat_std
|
27 |
+
|
28 |
+
|
29 |
+
def adaptive_instance_normalization(content_feat, style_feat):
|
30 |
+
"""Adaptive instance normalization.
|
31 |
+
|
32 |
+
Adjust the reference features to have the similar color and illuminations
|
33 |
+
as those in the degradate features.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
content_feat (Tensor): The reference feature.
|
37 |
+
style_feat (Tensor): The degradate features.
|
38 |
+
"""
|
39 |
+
size = content_feat.size()
|
40 |
+
style_mean, style_std = calc_mean_std(style_feat)
|
41 |
+
content_mean, content_std = calc_mean_std(content_feat)
|
42 |
+
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
43 |
+
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
44 |
+
|
45 |
+
|
46 |
+
class PositionEmbeddingSine(nn.Module):
|
47 |
+
"""
|
48 |
+
This is a more standard version of the position embedding, very similar to the one
|
49 |
+
used by the Attention is all you need paper, generalized to work on images.
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
53 |
+
super().__init__()
|
54 |
+
self.num_pos_feats = num_pos_feats
|
55 |
+
self.temperature = temperature
|
56 |
+
self.normalize = normalize
|
57 |
+
if scale is not None and normalize is False:
|
58 |
+
raise ValueError("normalize should be True if scale is passed")
|
59 |
+
if scale is None:
|
60 |
+
scale = 2 * math.pi
|
61 |
+
self.scale = scale
|
62 |
+
|
63 |
+
def forward(self, x, mask=None):
|
64 |
+
if mask is None:
|
65 |
+
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
|
66 |
+
not_mask = ~mask
|
67 |
+
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
68 |
+
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
69 |
+
if self.normalize:
|
70 |
+
eps = 1e-6
|
71 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
72 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
73 |
+
|
74 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
75 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
76 |
+
|
77 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
78 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
79 |
+
pos_x = torch.stack(
|
80 |
+
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
81 |
+
).flatten(3)
|
82 |
+
pos_y = torch.stack(
|
83 |
+
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
84 |
+
).flatten(3)
|
85 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
86 |
+
return pos
|
87 |
+
|
88 |
+
def _get_activation_fn(activation):
|
89 |
+
"""Return an activation function given a string"""
|
90 |
+
if activation == "relu":
|
91 |
+
return F.relu
|
92 |
+
if activation == "gelu":
|
93 |
+
return F.gelu
|
94 |
+
if activation == "glu":
|
95 |
+
return F.glu
|
96 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
97 |
+
|
98 |
+
|
99 |
+
class TransformerSALayer(nn.Module):
|
100 |
+
def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
|
101 |
+
super().__init__()
|
102 |
+
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
|
103 |
+
# Implementation of Feedforward model - MLP
|
104 |
+
self.linear1 = nn.Linear(embed_dim, dim_mlp)
|
105 |
+
self.dropout = nn.Dropout(dropout)
|
106 |
+
self.linear2 = nn.Linear(dim_mlp, embed_dim)
|
107 |
+
|
108 |
+
self.norm1 = nn.LayerNorm(embed_dim)
|
109 |
+
self.norm2 = nn.LayerNorm(embed_dim)
|
110 |
+
self.dropout1 = nn.Dropout(dropout)
|
111 |
+
self.dropout2 = nn.Dropout(dropout)
|
112 |
+
|
113 |
+
self.activation = _get_activation_fn(activation)
|
114 |
+
|
115 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
116 |
+
return tensor if pos is None else tensor + pos
|
117 |
+
|
118 |
+
def forward(self, tgt,
|
119 |
+
tgt_mask: Optional[Tensor] = None,
|
120 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
121 |
+
query_pos: Optional[Tensor] = None):
|
122 |
+
|
123 |
+
# self attention
|
124 |
+
tgt2 = self.norm1(tgt)
|
125 |
+
q = k = self.with_pos_embed(tgt2, query_pos)
|
126 |
+
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
127 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
128 |
+
tgt = tgt + self.dropout1(tgt2)
|
129 |
+
|
130 |
+
# ffn
|
131 |
+
tgt2 = self.norm2(tgt)
|
132 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
133 |
+
tgt = tgt + self.dropout2(tgt2)
|
134 |
+
return tgt
|
135 |
+
|
136 |
+
class Fuse_sft_block(nn.Module):
|
137 |
+
def __init__(self, in_ch, out_ch):
|
138 |
+
super().__init__()
|
139 |
+
self.encode_enc = ResBlock(2*in_ch, out_ch)
|
140 |
+
|
141 |
+
self.scale = nn.Sequential(
|
142 |
+
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
143 |
+
nn.LeakyReLU(0.2, True),
|
144 |
+
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
|
145 |
+
|
146 |
+
self.shift = nn.Sequential(
|
147 |
+
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
148 |
+
nn.LeakyReLU(0.2, True),
|
149 |
+
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
|
150 |
+
|
151 |
+
def forward(self, enc_feat, dec_feat, w=1):
|
152 |
+
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
|
153 |
+
scale = self.scale(enc_feat)
|
154 |
+
shift = self.shift(enc_feat)
|
155 |
+
residual = w * (dec_feat * scale + shift)
|
156 |
+
out = dec_feat + residual
|
157 |
+
return out
|
158 |
+
|
159 |
+
|
160 |
+
@ARCH_REGISTRY.register()
|
161 |
+
class CodeFormer(VQAutoEncoder):
|
162 |
+
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
163 |
+
codebook_size=1024, latent_size=256,
|
164 |
+
connect_list=['32', '64', '128', '256'],
|
165 |
+
fix_modules=['quantize','generator'], vqgan_path=None):
|
166 |
+
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
167 |
+
|
168 |
+
if vqgan_path is not None:
|
169 |
+
self.load_state_dict(
|
170 |
+
torch.load(vqgan_path, map_location='cpu')['params_ema'])
|
171 |
+
|
172 |
+
if fix_modules is not None:
|
173 |
+
for module in fix_modules:
|
174 |
+
for param in getattr(self, module).parameters():
|
175 |
+
param.requires_grad = False
|
176 |
+
|
177 |
+
self.connect_list = connect_list
|
178 |
+
self.n_layers = n_layers
|
179 |
+
self.dim_embd = dim_embd
|
180 |
+
self.dim_mlp = dim_embd*2
|
181 |
+
|
182 |
+
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
|
183 |
+
self.feat_emb = nn.Linear(256, self.dim_embd)
|
184 |
+
|
185 |
+
# transformer
|
186 |
+
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
187 |
+
for _ in range(self.n_layers)])
|
188 |
+
|
189 |
+
# logits_predict head
|
190 |
+
self.idx_pred_layer = nn.Sequential(
|
191 |
+
nn.LayerNorm(dim_embd),
|
192 |
+
nn.Linear(dim_embd, codebook_size, bias=False))
|
193 |
+
|
194 |
+
self.channels = {
|
195 |
+
'16': 512,
|
196 |
+
'32': 256,
|
197 |
+
'64': 256,
|
198 |
+
'128': 128,
|
199 |
+
'256': 128,
|
200 |
+
'512': 64,
|
201 |
+
}
|
202 |
+
|
203 |
+
# after second residual block for > 16, before attn layer for ==16
|
204 |
+
self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
|
205 |
+
# after first residual block for > 16, before attn layer for ==16
|
206 |
+
self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
|
207 |
+
|
208 |
+
# fuse_convs_dict
|
209 |
+
self.fuse_convs_dict = nn.ModuleDict()
|
210 |
+
for f_size in self.connect_list:
|
211 |
+
in_ch = self.channels[f_size]
|
212 |
+
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
|
213 |
+
|
214 |
+
def _init_weights(self, module):
|
215 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
216 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
217 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
218 |
+
module.bias.data.zero_()
|
219 |
+
elif isinstance(module, nn.LayerNorm):
|
220 |
+
module.bias.data.zero_()
|
221 |
+
module.weight.data.fill_(1.0)
|
222 |
+
|
223 |
+
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
|
224 |
+
# ################### Encoder #####################
|
225 |
+
enc_feat_dict = {}
|
226 |
+
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
227 |
+
for i, block in enumerate(self.encoder.blocks):
|
228 |
+
x = block(x)
|
229 |
+
if i in out_list:
|
230 |
+
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
231 |
+
|
232 |
+
lq_feat = x
|
233 |
+
# ################# Transformer ###################
|
234 |
+
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
|
235 |
+
pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
|
236 |
+
# BCHW -> BC(HW) -> (HW)BC
|
237 |
+
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
|
238 |
+
query_emb = feat_emb
|
239 |
+
# Transformer encoder
|
240 |
+
for layer in self.ft_layers:
|
241 |
+
query_emb = layer(query_emb, query_pos=pos_emb)
|
242 |
+
|
243 |
+
# output logits
|
244 |
+
logits = self.idx_pred_layer(query_emb) # (hw)bn
|
245 |
+
logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
|
246 |
+
|
247 |
+
if code_only: # for training stage II
|
248 |
+
# logits doesn't need softmax before cross_entropy loss
|
249 |
+
return logits, lq_feat
|
250 |
+
|
251 |
+
# ################# Quantization ###################
|
252 |
+
# if self.training:
|
253 |
+
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
|
254 |
+
# # b(hw)c -> bc(hw) -> bchw
|
255 |
+
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
|
256 |
+
# ------------
|
257 |
+
soft_one_hot = F.softmax(logits, dim=2)
|
258 |
+
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
|
259 |
+
quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
|
260 |
+
# preserve gradients
|
261 |
+
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
|
262 |
+
|
263 |
+
if detach_16:
|
264 |
+
quant_feat = quant_feat.detach() # for training stage III
|
265 |
+
if adain:
|
266 |
+
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
|
267 |
+
|
268 |
+
# ################## Generator ####################
|
269 |
+
x = quant_feat
|
270 |
+
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
271 |
+
|
272 |
+
for i, block in enumerate(self.generator.blocks):
|
273 |
+
x = block(x)
|
274 |
+
if i in fuse_list: # fuse after i-th block
|
275 |
+
f_size = str(x.shape[-1])
|
276 |
+
if w>0:
|
277 |
+
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
278 |
+
out = x
|
279 |
+
# logits doesn't need softmax before cross_entropy loss
|
280 |
+
return out, logits, lq_feat
|
blissful_tuner/codeformer/basicsr/archs/rrdbnet_arch.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn as nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
from codeformer.basicsr.utils.registry import ARCH_REGISTRY
|
6 |
+
from .arch_util import default_init_weights, make_layer, pixel_unshuffle
|
7 |
+
|
8 |
+
|
9 |
+
class ResidualDenseBlock(nn.Module):
|
10 |
+
"""Residual Dense Block.
|
11 |
+
|
12 |
+
Used in RRDB block in ESRGAN.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
num_feat (int): Channel number of intermediate features.
|
16 |
+
num_grow_ch (int): Channels for each growth.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, num_feat=64, num_grow_ch=32):
|
20 |
+
super(ResidualDenseBlock, self).__init__()
|
21 |
+
self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
|
22 |
+
self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
|
23 |
+
self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
24 |
+
self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
25 |
+
self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
|
26 |
+
|
27 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
28 |
+
|
29 |
+
# initialization
|
30 |
+
default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
31 |
+
|
32 |
+
def forward(self, x):
|
33 |
+
x1 = self.lrelu(self.conv1(x))
|
34 |
+
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
35 |
+
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
36 |
+
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
37 |
+
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
38 |
+
# Emperically, we use 0.2 to scale the residual for better performance
|
39 |
+
return x5 * 0.2 + x
|
40 |
+
|
41 |
+
|
42 |
+
class RRDB(nn.Module):
|
43 |
+
"""Residual in Residual Dense Block.
|
44 |
+
|
45 |
+
Used in RRDB-Net in ESRGAN.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
num_feat (int): Channel number of intermediate features.
|
49 |
+
num_grow_ch (int): Channels for each growth.
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(self, num_feat, num_grow_ch=32):
|
53 |
+
super(RRDB, self).__init__()
|
54 |
+
self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
|
55 |
+
self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
|
56 |
+
self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
out = self.rdb1(x)
|
60 |
+
out = self.rdb2(out)
|
61 |
+
out = self.rdb3(out)
|
62 |
+
# Emperically, we use 0.2 to scale the residual for better performance
|
63 |
+
return out * 0.2 + x
|
64 |
+
|
65 |
+
|
66 |
+
@ARCH_REGISTRY.register()
|
67 |
+
class RRDBNet(nn.Module):
|
68 |
+
"""Networks consisting of Residual in Residual Dense Block, which is used
|
69 |
+
in ESRGAN.
|
70 |
+
|
71 |
+
ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
|
72 |
+
|
73 |
+
We extend ESRGAN for scale x2 and scale x1.
|
74 |
+
Note: This is one option for scale 1, scale 2 in RRDBNet.
|
75 |
+
We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size
|
76 |
+
and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
num_in_ch (int): Channel number of inputs.
|
80 |
+
num_out_ch (int): Channel number of outputs.
|
81 |
+
num_feat (int): Channel number of intermediate features.
|
82 |
+
Default: 64
|
83 |
+
num_block (int): Block number in the trunk network. Defaults: 23
|
84 |
+
num_grow_ch (int): Channels for each growth. Default: 32.
|
85 |
+
"""
|
86 |
+
|
87 |
+
def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
|
88 |
+
super(RRDBNet, self).__init__()
|
89 |
+
self.scale = scale
|
90 |
+
if scale == 2:
|
91 |
+
num_in_ch = num_in_ch * 4
|
92 |
+
elif scale == 1:
|
93 |
+
num_in_ch = num_in_ch * 16
|
94 |
+
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
95 |
+
self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
|
96 |
+
self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
97 |
+
# upsample
|
98 |
+
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
99 |
+
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
100 |
+
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
101 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
102 |
+
|
103 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
104 |
+
|
105 |
+
def forward(self, x):
|
106 |
+
if self.scale == 2:
|
107 |
+
feat = pixel_unshuffle(x, scale=2)
|
108 |
+
elif self.scale == 1:
|
109 |
+
feat = pixel_unshuffle(x, scale=4)
|
110 |
+
else:
|
111 |
+
feat = x
|
112 |
+
feat = self.conv_first(feat)
|
113 |
+
body_feat = self.conv_body(self.body(feat))
|
114 |
+
feat = feat + body_feat
|
115 |
+
# upsample
|
116 |
+
feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
117 |
+
feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
118 |
+
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
|
119 |
+
return out
|
blissful_tuner/codeformer/basicsr/archs/vgg_arch.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from collections import OrderedDict
|
4 |
+
from torch import nn as nn
|
5 |
+
from torchvision.models import vgg as vgg
|
6 |
+
|
7 |
+
from codeformer.basicsr.utils.registry import ARCH_REGISTRY
|
8 |
+
|
9 |
+
VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
|
10 |
+
NAMES = {
|
11 |
+
'vgg11': [
|
12 |
+
'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
|
13 |
+
'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
|
14 |
+
'pool5'
|
15 |
+
],
|
16 |
+
'vgg13': [
|
17 |
+
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
|
18 |
+
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
|
19 |
+
'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
|
20 |
+
],
|
21 |
+
'vgg16': [
|
22 |
+
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
|
23 |
+
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
|
24 |
+
'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
|
25 |
+
'pool5'
|
26 |
+
],
|
27 |
+
'vgg19': [
|
28 |
+
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
|
29 |
+
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
|
30 |
+
'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
|
31 |
+
'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
|
32 |
+
]
|
33 |
+
}
|
34 |
+
|
35 |
+
|
36 |
+
def insert_bn(names):
|
37 |
+
"""Insert bn layer after each conv.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
names (list): The list of layer names.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
list: The list of layer names with bn layers.
|
44 |
+
"""
|
45 |
+
names_bn = []
|
46 |
+
for name in names:
|
47 |
+
names_bn.append(name)
|
48 |
+
if 'conv' in name:
|
49 |
+
position = name.replace('conv', '')
|
50 |
+
names_bn.append('bn' + position)
|
51 |
+
return names_bn
|
52 |
+
|
53 |
+
|
54 |
+
@ARCH_REGISTRY.register()
|
55 |
+
class VGGFeatureExtractor(nn.Module):
|
56 |
+
"""VGG network for feature extraction.
|
57 |
+
|
58 |
+
In this implementation, we allow users to choose whether use normalization
|
59 |
+
in the input feature and the type of vgg network. Note that the pretrained
|
60 |
+
path must fit the vgg type.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
layer_name_list (list[str]): Forward function returns the corresponding
|
64 |
+
features according to the layer_name_list.
|
65 |
+
Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
|
66 |
+
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
|
67 |
+
use_input_norm (bool): If True, normalize the input image. Importantly,
|
68 |
+
the input feature must in the range [0, 1]. Default: True.
|
69 |
+
range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
|
70 |
+
Default: False.
|
71 |
+
requires_grad (bool): If true, the parameters of VGG network will be
|
72 |
+
optimized. Default: False.
|
73 |
+
remove_pooling (bool): If true, the max pooling operations in VGG net
|
74 |
+
will be removed. Default: False.
|
75 |
+
pooling_stride (int): The stride of max pooling operation. Default: 2.
|
76 |
+
"""
|
77 |
+
|
78 |
+
def __init__(self,
|
79 |
+
layer_name_list,
|
80 |
+
vgg_type='vgg19',
|
81 |
+
use_input_norm=True,
|
82 |
+
range_norm=False,
|
83 |
+
requires_grad=False,
|
84 |
+
remove_pooling=False,
|
85 |
+
pooling_stride=2):
|
86 |
+
super(VGGFeatureExtractor, self).__init__()
|
87 |
+
|
88 |
+
self.layer_name_list = layer_name_list
|
89 |
+
self.use_input_norm = use_input_norm
|
90 |
+
self.range_norm = range_norm
|
91 |
+
|
92 |
+
self.names = NAMES[vgg_type.replace('_bn', '')]
|
93 |
+
if 'bn' in vgg_type:
|
94 |
+
self.names = insert_bn(self.names)
|
95 |
+
|
96 |
+
# only borrow layers that will be used to avoid unused params
|
97 |
+
max_idx = 0
|
98 |
+
for v in layer_name_list:
|
99 |
+
idx = self.names.index(v)
|
100 |
+
if idx > max_idx:
|
101 |
+
max_idx = idx
|
102 |
+
|
103 |
+
if os.path.exists(VGG_PRETRAIN_PATH):
|
104 |
+
vgg_net = getattr(vgg, vgg_type)(pretrained=False)
|
105 |
+
state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
|
106 |
+
vgg_net.load_state_dict(state_dict)
|
107 |
+
else:
|
108 |
+
vgg_net = getattr(vgg, vgg_type)(pretrained=True)
|
109 |
+
|
110 |
+
features = vgg_net.features[:max_idx + 1]
|
111 |
+
|
112 |
+
modified_net = OrderedDict()
|
113 |
+
for k, v in zip(self.names, features):
|
114 |
+
if 'pool' in k:
|
115 |
+
# if remove_pooling is true, pooling operation will be removed
|
116 |
+
if remove_pooling:
|
117 |
+
continue
|
118 |
+
else:
|
119 |
+
# in some cases, we may want to change the default stride
|
120 |
+
modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
|
121 |
+
else:
|
122 |
+
modified_net[k] = v
|
123 |
+
|
124 |
+
self.vgg_net = nn.Sequential(modified_net)
|
125 |
+
|
126 |
+
if not requires_grad:
|
127 |
+
self.vgg_net.eval()
|
128 |
+
for param in self.parameters():
|
129 |
+
param.requires_grad = False
|
130 |
+
else:
|
131 |
+
self.vgg_net.train()
|
132 |
+
for param in self.parameters():
|
133 |
+
param.requires_grad = True
|
134 |
+
|
135 |
+
if self.use_input_norm:
|
136 |
+
# the mean is for image with range [0, 1]
|
137 |
+
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
138 |
+
# the std is for image with range [0, 1]
|
139 |
+
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
140 |
+
|
141 |
+
def forward(self, x):
|
142 |
+
"""Forward function.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
x (Tensor): Input tensor with shape (n, c, h, w).
|
146 |
+
|
147 |
+
Returns:
|
148 |
+
Tensor: Forward results.
|
149 |
+
"""
|
150 |
+
if self.range_norm:
|
151 |
+
x = (x + 1) / 2
|
152 |
+
if self.use_input_norm:
|
153 |
+
x = (x - self.mean) / self.std
|
154 |
+
output = {}
|
155 |
+
|
156 |
+
for key, layer in self.vgg_net._modules.items():
|
157 |
+
x = layer(x)
|
158 |
+
if key in self.layer_name_list:
|
159 |
+
output[key] = x.clone()
|
160 |
+
|
161 |
+
return output
|
blissful_tuner/codeformer/basicsr/archs/vqgan_arch.py
ADDED
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
|
3 |
+
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
4 |
+
|
5 |
+
'''
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import copy
|
11 |
+
from codeformer.basicsr.utils import get_root_logger
|
12 |
+
from codeformer.basicsr.utils.registry import ARCH_REGISTRY
|
13 |
+
|
14 |
+
def normalize(in_channels):
|
15 |
+
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
16 |
+
|
17 |
+
|
18 |
+
@torch.jit.script
|
19 |
+
def swish(x):
|
20 |
+
return x*torch.sigmoid(x)
|
21 |
+
|
22 |
+
|
23 |
+
# Define VQVAE classes
|
24 |
+
class VectorQuantizer(nn.Module):
|
25 |
+
def __init__(self, codebook_size, emb_dim, beta):
|
26 |
+
super(VectorQuantizer, self).__init__()
|
27 |
+
self.codebook_size = codebook_size # number of embeddings
|
28 |
+
self.emb_dim = emb_dim # dimension of embedding
|
29 |
+
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
30 |
+
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
|
31 |
+
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
|
32 |
+
|
33 |
+
def forward(self, z):
|
34 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
35 |
+
z = z.permute(0, 2, 3, 1).contiguous()
|
36 |
+
z_flattened = z.view(-1, self.emb_dim)
|
37 |
+
|
38 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
39 |
+
d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
|
40 |
+
2 * torch.matmul(z_flattened, self.embedding.weight.t())
|
41 |
+
|
42 |
+
mean_distance = torch.mean(d)
|
43 |
+
# find closest encodings
|
44 |
+
min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
|
45 |
+
# min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
|
46 |
+
# [0-1], higher score, higher confidence
|
47 |
+
# min_encoding_scores = torch.exp(-min_encoding_scores/10)
|
48 |
+
|
49 |
+
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
|
50 |
+
min_encodings.scatter_(1, min_encoding_indices, 1)
|
51 |
+
|
52 |
+
# get quantized latent vectors
|
53 |
+
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
54 |
+
# compute loss for embedding
|
55 |
+
loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
56 |
+
# preserve gradients
|
57 |
+
z_q = z + (z_q - z).detach()
|
58 |
+
|
59 |
+
# perplexity
|
60 |
+
e_mean = torch.mean(min_encodings, dim=0)
|
61 |
+
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
62 |
+
# reshape back to match original input shape
|
63 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
64 |
+
|
65 |
+
return z_q, loss, {
|
66 |
+
"perplexity": perplexity,
|
67 |
+
"min_encodings": min_encodings,
|
68 |
+
"min_encoding_indices": min_encoding_indices,
|
69 |
+
"mean_distance": mean_distance
|
70 |
+
}
|
71 |
+
|
72 |
+
def get_codebook_feat(self, indices, shape):
|
73 |
+
# input indices: batch*token_num -> (batch*token_num)*1
|
74 |
+
# shape: batch, height, width, channel
|
75 |
+
indices = indices.view(-1,1)
|
76 |
+
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
|
77 |
+
min_encodings.scatter_(1, indices, 1)
|
78 |
+
# get quantized latent vectors
|
79 |
+
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
80 |
+
|
81 |
+
if shape is not None: # reshape back to match original input shape
|
82 |
+
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
|
83 |
+
|
84 |
+
return z_q
|
85 |
+
|
86 |
+
|
87 |
+
class GumbelQuantizer(nn.Module):
|
88 |
+
def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
|
89 |
+
super().__init__()
|
90 |
+
self.codebook_size = codebook_size # number of embeddings
|
91 |
+
self.emb_dim = emb_dim # dimension of embedding
|
92 |
+
self.straight_through = straight_through
|
93 |
+
self.temperature = temp_init
|
94 |
+
self.kl_weight = kl_weight
|
95 |
+
self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
|
96 |
+
self.embed = nn.Embedding(codebook_size, emb_dim)
|
97 |
+
|
98 |
+
def forward(self, z):
|
99 |
+
hard = self.straight_through if self.training else True
|
100 |
+
|
101 |
+
logits = self.proj(z)
|
102 |
+
|
103 |
+
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
|
104 |
+
|
105 |
+
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
|
106 |
+
|
107 |
+
# + kl divergence to the prior loss
|
108 |
+
qy = F.softmax(logits, dim=1)
|
109 |
+
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
|
110 |
+
min_encoding_indices = soft_one_hot.argmax(dim=1)
|
111 |
+
|
112 |
+
return z_q, diff, {
|
113 |
+
"min_encoding_indices": min_encoding_indices
|
114 |
+
}
|
115 |
+
|
116 |
+
|
117 |
+
class Downsample(nn.Module):
|
118 |
+
def __init__(self, in_channels):
|
119 |
+
super().__init__()
|
120 |
+
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
121 |
+
|
122 |
+
def forward(self, x):
|
123 |
+
pad = (0, 1, 0, 1)
|
124 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
125 |
+
x = self.conv(x)
|
126 |
+
return x
|
127 |
+
|
128 |
+
|
129 |
+
class Upsample(nn.Module):
|
130 |
+
def __init__(self, in_channels):
|
131 |
+
super().__init__()
|
132 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
133 |
+
|
134 |
+
def forward(self, x):
|
135 |
+
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
136 |
+
x = self.conv(x)
|
137 |
+
|
138 |
+
return x
|
139 |
+
|
140 |
+
|
141 |
+
class ResBlock(nn.Module):
|
142 |
+
def __init__(self, in_channels, out_channels=None):
|
143 |
+
super(ResBlock, self).__init__()
|
144 |
+
self.in_channels = in_channels
|
145 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
146 |
+
self.norm1 = normalize(in_channels)
|
147 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
148 |
+
self.norm2 = normalize(out_channels)
|
149 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
150 |
+
if self.in_channels != self.out_channels:
|
151 |
+
self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
152 |
+
|
153 |
+
def forward(self, x_in):
|
154 |
+
x = x_in
|
155 |
+
x = self.norm1(x)
|
156 |
+
x = swish(x)
|
157 |
+
x = self.conv1(x)
|
158 |
+
x = self.norm2(x)
|
159 |
+
x = swish(x)
|
160 |
+
x = self.conv2(x)
|
161 |
+
if self.in_channels != self.out_channels:
|
162 |
+
x_in = self.conv_out(x_in)
|
163 |
+
|
164 |
+
return x + x_in
|
165 |
+
|
166 |
+
|
167 |
+
class AttnBlock(nn.Module):
|
168 |
+
def __init__(self, in_channels):
|
169 |
+
super().__init__()
|
170 |
+
self.in_channels = in_channels
|
171 |
+
|
172 |
+
self.norm = normalize(in_channels)
|
173 |
+
self.q = torch.nn.Conv2d(
|
174 |
+
in_channels,
|
175 |
+
in_channels,
|
176 |
+
kernel_size=1,
|
177 |
+
stride=1,
|
178 |
+
padding=0
|
179 |
+
)
|
180 |
+
self.k = torch.nn.Conv2d(
|
181 |
+
in_channels,
|
182 |
+
in_channels,
|
183 |
+
kernel_size=1,
|
184 |
+
stride=1,
|
185 |
+
padding=0
|
186 |
+
)
|
187 |
+
self.v = torch.nn.Conv2d(
|
188 |
+
in_channels,
|
189 |
+
in_channels,
|
190 |
+
kernel_size=1,
|
191 |
+
stride=1,
|
192 |
+
padding=0
|
193 |
+
)
|
194 |
+
self.proj_out = torch.nn.Conv2d(
|
195 |
+
in_channels,
|
196 |
+
in_channels,
|
197 |
+
kernel_size=1,
|
198 |
+
stride=1,
|
199 |
+
padding=0
|
200 |
+
)
|
201 |
+
|
202 |
+
def forward(self, x):
|
203 |
+
h_ = x
|
204 |
+
h_ = self.norm(h_)
|
205 |
+
q = self.q(h_)
|
206 |
+
k = self.k(h_)
|
207 |
+
v = self.v(h_)
|
208 |
+
|
209 |
+
# compute attention
|
210 |
+
b, c, h, w = q.shape
|
211 |
+
q = q.reshape(b, c, h*w)
|
212 |
+
q = q.permute(0, 2, 1)
|
213 |
+
k = k.reshape(b, c, h*w)
|
214 |
+
w_ = torch.bmm(q, k)
|
215 |
+
w_ = w_ * (int(c)**(-0.5))
|
216 |
+
w_ = F.softmax(w_, dim=2)
|
217 |
+
|
218 |
+
# attend to values
|
219 |
+
v = v.reshape(b, c, h*w)
|
220 |
+
w_ = w_.permute(0, 2, 1)
|
221 |
+
h_ = torch.bmm(v, w_)
|
222 |
+
h_ = h_.reshape(b, c, h, w)
|
223 |
+
|
224 |
+
h_ = self.proj_out(h_)
|
225 |
+
|
226 |
+
return x+h_
|
227 |
+
|
228 |
+
|
229 |
+
class Encoder(nn.Module):
|
230 |
+
def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
|
231 |
+
super().__init__()
|
232 |
+
self.nf = nf
|
233 |
+
self.num_resolutions = len(ch_mult)
|
234 |
+
self.num_res_blocks = num_res_blocks
|
235 |
+
self.resolution = resolution
|
236 |
+
self.attn_resolutions = attn_resolutions
|
237 |
+
|
238 |
+
curr_res = self.resolution
|
239 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
240 |
+
|
241 |
+
blocks = []
|
242 |
+
# initial convultion
|
243 |
+
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
|
244 |
+
|
245 |
+
# residual and downsampling blocks, with attention on smaller res (16x16)
|
246 |
+
for i in range(self.num_resolutions):
|
247 |
+
block_in_ch = nf * in_ch_mult[i]
|
248 |
+
block_out_ch = nf * ch_mult[i]
|
249 |
+
for _ in range(self.num_res_blocks):
|
250 |
+
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
251 |
+
block_in_ch = block_out_ch
|
252 |
+
if curr_res in attn_resolutions:
|
253 |
+
blocks.append(AttnBlock(block_in_ch))
|
254 |
+
|
255 |
+
if i != self.num_resolutions - 1:
|
256 |
+
blocks.append(Downsample(block_in_ch))
|
257 |
+
curr_res = curr_res // 2
|
258 |
+
|
259 |
+
# non-local attention block
|
260 |
+
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
261 |
+
blocks.append(AttnBlock(block_in_ch))
|
262 |
+
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
263 |
+
|
264 |
+
# normalise and convert to latent size
|
265 |
+
blocks.append(normalize(block_in_ch))
|
266 |
+
blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
|
267 |
+
self.blocks = nn.ModuleList(blocks)
|
268 |
+
|
269 |
+
def forward(self, x):
|
270 |
+
for block in self.blocks:
|
271 |
+
x = block(x)
|
272 |
+
|
273 |
+
return x
|
274 |
+
|
275 |
+
|
276 |
+
class Generator(nn.Module):
|
277 |
+
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
|
278 |
+
super().__init__()
|
279 |
+
self.nf = nf
|
280 |
+
self.ch_mult = ch_mult
|
281 |
+
self.num_resolutions = len(self.ch_mult)
|
282 |
+
self.num_res_blocks = res_blocks
|
283 |
+
self.resolution = img_size
|
284 |
+
self.attn_resolutions = attn_resolutions
|
285 |
+
self.in_channels = emb_dim
|
286 |
+
self.out_channels = 3
|
287 |
+
block_in_ch = self.nf * self.ch_mult[-1]
|
288 |
+
curr_res = self.resolution // 2 ** (self.num_resolutions-1)
|
289 |
+
|
290 |
+
blocks = []
|
291 |
+
# initial conv
|
292 |
+
blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
|
293 |
+
|
294 |
+
# non-local attention block
|
295 |
+
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
296 |
+
blocks.append(AttnBlock(block_in_ch))
|
297 |
+
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
298 |
+
|
299 |
+
for i in reversed(range(self.num_resolutions)):
|
300 |
+
block_out_ch = self.nf * self.ch_mult[i]
|
301 |
+
|
302 |
+
for _ in range(self.num_res_blocks):
|
303 |
+
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
304 |
+
block_in_ch = block_out_ch
|
305 |
+
|
306 |
+
if curr_res in self.attn_resolutions:
|
307 |
+
blocks.append(AttnBlock(block_in_ch))
|
308 |
+
|
309 |
+
if i != 0:
|
310 |
+
blocks.append(Upsample(block_in_ch))
|
311 |
+
curr_res = curr_res * 2
|
312 |
+
|
313 |
+
blocks.append(normalize(block_in_ch))
|
314 |
+
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
|
315 |
+
|
316 |
+
self.blocks = nn.ModuleList(blocks)
|
317 |
+
|
318 |
+
|
319 |
+
def forward(self, x):
|
320 |
+
for block in self.blocks:
|
321 |
+
x = block(x)
|
322 |
+
|
323 |
+
return x
|
324 |
+
|
325 |
+
|
326 |
+
@ARCH_REGISTRY.register()
|
327 |
+
class VQAutoEncoder(nn.Module):
|
328 |
+
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
|
329 |
+
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
|
330 |
+
super().__init__()
|
331 |
+
logger = get_root_logger()
|
332 |
+
self.in_channels = 3
|
333 |
+
self.nf = nf
|
334 |
+
self.n_blocks = res_blocks
|
335 |
+
self.codebook_size = codebook_size
|
336 |
+
self.embed_dim = emb_dim
|
337 |
+
self.ch_mult = ch_mult
|
338 |
+
self.resolution = img_size
|
339 |
+
self.attn_resolutions = attn_resolutions
|
340 |
+
self.quantizer_type = quantizer
|
341 |
+
self.encoder = Encoder(
|
342 |
+
self.in_channels,
|
343 |
+
self.nf,
|
344 |
+
self.embed_dim,
|
345 |
+
self.ch_mult,
|
346 |
+
self.n_blocks,
|
347 |
+
self.resolution,
|
348 |
+
self.attn_resolutions
|
349 |
+
)
|
350 |
+
if self.quantizer_type == "nearest":
|
351 |
+
self.beta = beta #0.25
|
352 |
+
self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
|
353 |
+
elif self.quantizer_type == "gumbel":
|
354 |
+
self.gumbel_num_hiddens = emb_dim
|
355 |
+
self.straight_through = gumbel_straight_through
|
356 |
+
self.kl_weight = gumbel_kl_weight
|
357 |
+
self.quantize = GumbelQuantizer(
|
358 |
+
self.codebook_size,
|
359 |
+
self.embed_dim,
|
360 |
+
self.gumbel_num_hiddens,
|
361 |
+
self.straight_through,
|
362 |
+
self.kl_weight
|
363 |
+
)
|
364 |
+
self.generator = Generator(
|
365 |
+
self.nf,
|
366 |
+
self.embed_dim,
|
367 |
+
self.ch_mult,
|
368 |
+
self.n_blocks,
|
369 |
+
self.resolution,
|
370 |
+
self.attn_resolutions
|
371 |
+
)
|
372 |
+
|
373 |
+
if model_path is not None:
|
374 |
+
chkpt = torch.load(model_path, map_location='cpu')
|
375 |
+
if 'params_ema' in chkpt:
|
376 |
+
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
|
377 |
+
logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
|
378 |
+
elif 'params' in chkpt:
|
379 |
+
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
380 |
+
logger.info(f'vqgan is loaded from: {model_path} [params]')
|
381 |
+
else:
|
382 |
+
raise ValueError(f'Wrong params!')
|
383 |
+
|
384 |
+
|
385 |
+
def forward(self, x):
|
386 |
+
x = self.encoder(x)
|
387 |
+
quant, codebook_loss, quant_stats = self.quantize(x)
|
388 |
+
x = self.generator(quant)
|
389 |
+
return x, codebook_loss, quant_stats
|
390 |
+
|
391 |
+
|
392 |
+
|
393 |
+
# patch based discriminator
|
394 |
+
@ARCH_REGISTRY.register()
|
395 |
+
class VQGANDiscriminator(nn.Module):
|
396 |
+
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
|
397 |
+
super().__init__()
|
398 |
+
|
399 |
+
layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
|
400 |
+
ndf_mult = 1
|
401 |
+
ndf_mult_prev = 1
|
402 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
403 |
+
ndf_mult_prev = ndf_mult
|
404 |
+
ndf_mult = min(2 ** n, 8)
|
405 |
+
layers += [
|
406 |
+
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
|
407 |
+
nn.BatchNorm2d(ndf * ndf_mult),
|
408 |
+
nn.LeakyReLU(0.2, True)
|
409 |
+
]
|
410 |
+
|
411 |
+
ndf_mult_prev = ndf_mult
|
412 |
+
ndf_mult = min(2 ** n_layers, 8)
|
413 |
+
|
414 |
+
layers += [
|
415 |
+
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
|
416 |
+
nn.BatchNorm2d(ndf * ndf_mult),
|
417 |
+
nn.LeakyReLU(0.2, True)
|
418 |
+
]
|
419 |
+
|
420 |
+
layers += [
|
421 |
+
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
|
422 |
+
self.main = nn.Sequential(*layers)
|
423 |
+
|
424 |
+
if model_path is not None:
|
425 |
+
chkpt = torch.load(model_path, map_location='cpu')
|
426 |
+
if 'params_d' in chkpt:
|
427 |
+
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
|
428 |
+
elif 'params' in chkpt:
|
429 |
+
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
430 |
+
else:
|
431 |
+
raise ValueError(f'Wrong params!')
|
432 |
+
|
433 |
+
def forward(self, x):
|
434 |
+
return self.main(x)
|
blissful_tuner/codeformer/basicsr/data/__init__.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import numpy as np
|
3 |
+
import random
|
4 |
+
import torch
|
5 |
+
import torch.utils.data
|
6 |
+
from copy import deepcopy
|
7 |
+
from functools import partial
|
8 |
+
from os import path as osp
|
9 |
+
|
10 |
+
from codeformer.basicsr.data.prefetch_dataloader import PrefetchDataLoader
|
11 |
+
from codeformer.basicsr.utils import get_root_logger, scandir
|
12 |
+
from codeformer.basicsr.utils.dist_util import get_dist_info
|
13 |
+
from codeformer.basicsr.utils.registry import DATASET_REGISTRY
|
14 |
+
|
15 |
+
__all__ = ['build_dataset', 'build_dataloader']
|
16 |
+
|
17 |
+
# automatically scan and import dataset modules for registry
|
18 |
+
# scan all the files under the data folder with '_dataset' in file names
|
19 |
+
data_folder = osp.dirname(osp.abspath(__file__))
|
20 |
+
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
21 |
+
# import all the dataset modules
|
22 |
+
_dataset_modules = [importlib.import_module(f'codeformer.basicsr.data.{file_name}') for file_name in dataset_filenames]
|
23 |
+
|
24 |
+
|
25 |
+
def build_dataset(dataset_opt):
|
26 |
+
"""Build dataset from options.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
dataset_opt (dict): Configuration for dataset. It must constain:
|
30 |
+
name (str): Dataset name.
|
31 |
+
type (str): Dataset type.
|
32 |
+
"""
|
33 |
+
dataset_opt = deepcopy(dataset_opt)
|
34 |
+
dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
|
35 |
+
logger = get_root_logger()
|
36 |
+
logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.')
|
37 |
+
return dataset
|
38 |
+
|
39 |
+
|
40 |
+
def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
|
41 |
+
"""Build dataloader.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
45 |
+
dataset_opt (dict): Dataset options. It contains the following keys:
|
46 |
+
phase (str): 'train' or 'val'.
|
47 |
+
num_worker_per_gpu (int): Number of workers for each GPU.
|
48 |
+
batch_size_per_gpu (int): Training batch size for each GPU.
|
49 |
+
num_gpu (int): Number of GPUs. Used only in the train phase.
|
50 |
+
Default: 1.
|
51 |
+
dist (bool): Whether in distributed training. Used only in the train
|
52 |
+
phase. Default: False.
|
53 |
+
sampler (torch.utils.data.sampler): Data sampler. Default: None.
|
54 |
+
seed (int | None): Seed. Default: None
|
55 |
+
"""
|
56 |
+
phase = dataset_opt['phase']
|
57 |
+
rank, _ = get_dist_info()
|
58 |
+
if phase == 'train':
|
59 |
+
if dist: # distributed training
|
60 |
+
batch_size = dataset_opt['batch_size_per_gpu']
|
61 |
+
num_workers = dataset_opt['num_worker_per_gpu']
|
62 |
+
else: # non-distributed training
|
63 |
+
multiplier = 1 if num_gpu == 0 else num_gpu
|
64 |
+
batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
|
65 |
+
num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
|
66 |
+
dataloader_args = dict(
|
67 |
+
dataset=dataset,
|
68 |
+
batch_size=batch_size,
|
69 |
+
shuffle=False,
|
70 |
+
num_workers=num_workers,
|
71 |
+
sampler=sampler,
|
72 |
+
drop_last=True)
|
73 |
+
if sampler is None:
|
74 |
+
dataloader_args['shuffle'] = True
|
75 |
+
dataloader_args['worker_init_fn'] = partial(
|
76 |
+
worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
|
77 |
+
elif phase in ['val', 'test']: # validation
|
78 |
+
dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
|
79 |
+
else:
|
80 |
+
raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")
|
81 |
+
|
82 |
+
dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
|
83 |
+
|
84 |
+
prefetch_mode = dataset_opt.get('prefetch_mode')
|
85 |
+
if prefetch_mode == 'cpu': # CPUPrefetcher
|
86 |
+
num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
|
87 |
+
logger = get_root_logger()
|
88 |
+
logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}')
|
89 |
+
return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
|
90 |
+
else:
|
91 |
+
# prefetch_mode=None: Normal dataloader
|
92 |
+
# prefetch_mode='cuda': dataloader for CUDAPrefetcher
|
93 |
+
return torch.utils.data.DataLoader(**dataloader_args)
|
94 |
+
|
95 |
+
|
96 |
+
def worker_init_fn(worker_id, num_workers, rank, seed):
|
97 |
+
# Set the worker seed to num_workers * rank + worker_id + seed
|
98 |
+
worker_seed = num_workers * rank + worker_id + seed
|
99 |
+
np.random.seed(worker_seed)
|
100 |
+
random.seed(worker_seed)
|
blissful_tuner/codeformer/basicsr/data/data_sampler.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch.utils.data.sampler import Sampler
|
4 |
+
|
5 |
+
|
6 |
+
class EnlargedSampler(Sampler):
|
7 |
+
"""Sampler that restricts data loading to a subset of the dataset.
|
8 |
+
|
9 |
+
Modified from torch.utils.data.distributed.DistributedSampler
|
10 |
+
Support enlarging the dataset for iteration-based training, for saving
|
11 |
+
time when restart the dataloader after each epoch
|
12 |
+
|
13 |
+
Args:
|
14 |
+
dataset (torch.utils.data.Dataset): Dataset used for sampling.
|
15 |
+
num_replicas (int | None): Number of processes participating in
|
16 |
+
the training. It is usually the world_size.
|
17 |
+
rank (int | None): Rank of the current process within num_replicas.
|
18 |
+
ratio (int): Enlarging ratio. Default: 1.
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, dataset, num_replicas, rank, ratio=1):
|
22 |
+
self.dataset = dataset
|
23 |
+
self.num_replicas = num_replicas
|
24 |
+
self.rank = rank
|
25 |
+
self.epoch = 0
|
26 |
+
self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
|
27 |
+
self.total_size = self.num_samples * self.num_replicas
|
28 |
+
|
29 |
+
def __iter__(self):
|
30 |
+
# deterministically shuffle based on epoch
|
31 |
+
g = torch.Generator()
|
32 |
+
g.manual_seed(self.epoch)
|
33 |
+
indices = torch.randperm(self.total_size, generator=g).tolist()
|
34 |
+
|
35 |
+
dataset_size = len(self.dataset)
|
36 |
+
indices = [v % dataset_size for v in indices]
|
37 |
+
|
38 |
+
# subsample
|
39 |
+
indices = indices[self.rank:self.total_size:self.num_replicas]
|
40 |
+
assert len(indices) == self.num_samples
|
41 |
+
|
42 |
+
return iter(indices)
|
43 |
+
|
44 |
+
def __len__(self):
|
45 |
+
return self.num_samples
|
46 |
+
|
47 |
+
def set_epoch(self, epoch):
|
48 |
+
self.epoch = epoch
|
blissful_tuner/codeformer/basicsr/data/data_util.py
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from os import path as osp
|
6 |
+
from PIL import Image, ImageDraw
|
7 |
+
from torch.nn import functional as F
|
8 |
+
|
9 |
+
from codeformer.basicsr.data.transforms import mod_crop
|
10 |
+
from codeformer.basicsr.utils import img2tensor, scandir
|
11 |
+
|
12 |
+
|
13 |
+
def read_img_seq(path, require_mod_crop=False, scale=1):
|
14 |
+
"""Read a sequence of images from a given folder path.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
path (list[str] | str): List of image paths or image folder path.
|
18 |
+
require_mod_crop (bool): Require mod crop for each image.
|
19 |
+
Default: False.
|
20 |
+
scale (int): Scale factor for mod_crop. Default: 1.
|
21 |
+
|
22 |
+
Returns:
|
23 |
+
Tensor: size (t, c, h, w), RGB, [0, 1].
|
24 |
+
"""
|
25 |
+
if isinstance(path, list):
|
26 |
+
img_paths = path
|
27 |
+
else:
|
28 |
+
img_paths = sorted(list(scandir(path, full_path=True)))
|
29 |
+
imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
|
30 |
+
if require_mod_crop:
|
31 |
+
imgs = [mod_crop(img, scale) for img in imgs]
|
32 |
+
imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
|
33 |
+
imgs = torch.stack(imgs, dim=0)
|
34 |
+
return imgs
|
35 |
+
|
36 |
+
|
37 |
+
def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
|
38 |
+
"""Generate an index list for reading `num_frames` frames from a sequence
|
39 |
+
of images.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
crt_idx (int): Current center index.
|
43 |
+
max_frame_num (int): Max number of the sequence of images (from 1).
|
44 |
+
num_frames (int): Reading num_frames frames.
|
45 |
+
padding (str): Padding mode, one of
|
46 |
+
'replicate' | 'reflection' | 'reflection_circle' | 'circle'
|
47 |
+
Examples: current_idx = 0, num_frames = 5
|
48 |
+
The generated frame indices under different padding mode:
|
49 |
+
replicate: [0, 0, 0, 1, 2]
|
50 |
+
reflection: [2, 1, 0, 1, 2]
|
51 |
+
reflection_circle: [4, 3, 0, 1, 2]
|
52 |
+
circle: [3, 4, 0, 1, 2]
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
list[int]: A list of indices.
|
56 |
+
"""
|
57 |
+
assert num_frames % 2 == 1, 'num_frames should be an odd number.'
|
58 |
+
assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
|
59 |
+
|
60 |
+
max_frame_num = max_frame_num - 1 # start from 0
|
61 |
+
num_pad = num_frames // 2
|
62 |
+
|
63 |
+
indices = []
|
64 |
+
for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
|
65 |
+
if i < 0:
|
66 |
+
if padding == 'replicate':
|
67 |
+
pad_idx = 0
|
68 |
+
elif padding == 'reflection':
|
69 |
+
pad_idx = -i
|
70 |
+
elif padding == 'reflection_circle':
|
71 |
+
pad_idx = crt_idx + num_pad - i
|
72 |
+
else:
|
73 |
+
pad_idx = num_frames + i
|
74 |
+
elif i > max_frame_num:
|
75 |
+
if padding == 'replicate':
|
76 |
+
pad_idx = max_frame_num
|
77 |
+
elif padding == 'reflection':
|
78 |
+
pad_idx = max_frame_num * 2 - i
|
79 |
+
elif padding == 'reflection_circle':
|
80 |
+
pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
|
81 |
+
else:
|
82 |
+
pad_idx = i - num_frames
|
83 |
+
else:
|
84 |
+
pad_idx = i
|
85 |
+
indices.append(pad_idx)
|
86 |
+
return indices
|
87 |
+
|
88 |
+
|
89 |
+
def paired_paths_from_lmdb(folders, keys):
|
90 |
+
"""Generate paired paths from lmdb files.
|
91 |
+
|
92 |
+
Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
|
93 |
+
|
94 |
+
lq.lmdb
|
95 |
+
├── data.mdb
|
96 |
+
├── lock.mdb
|
97 |
+
├── meta_info.txt
|
98 |
+
|
99 |
+
The data.mdb and lock.mdb are standard lmdb files and you can refer to
|
100 |
+
https://lmdb.readthedocs.io/en/release/ for more details.
|
101 |
+
|
102 |
+
The meta_info.txt is a specified txt file to record the meta information
|
103 |
+
of our datasets. It will be automatically created when preparing
|
104 |
+
datasets by our provided dataset tools.
|
105 |
+
Each line in the txt file records
|
106 |
+
1)image name (with extension),
|
107 |
+
2)image shape,
|
108 |
+
3)compression level, separated by a white space.
|
109 |
+
Example: `baboon.png (120,125,3) 1`
|
110 |
+
|
111 |
+
We use the image name without extension as the lmdb key.
|
112 |
+
Note that we use the same key for the corresponding lq and gt images.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
folders (list[str]): A list of folder path. The order of list should
|
116 |
+
be [input_folder, gt_folder].
|
117 |
+
keys (list[str]): A list of keys identifying folders. The order should
|
118 |
+
be in consistent with folders, e.g., ['lq', 'gt'].
|
119 |
+
Note that this key is different from lmdb keys.
|
120 |
+
|
121 |
+
Returns:
|
122 |
+
list[str]: Returned path list.
|
123 |
+
"""
|
124 |
+
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
|
125 |
+
f'But got {len(folders)}')
|
126 |
+
assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
|
127 |
+
input_folder, gt_folder = folders
|
128 |
+
input_key, gt_key = keys
|
129 |
+
|
130 |
+
if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
|
131 |
+
raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
|
132 |
+
f'formats. But received {input_key}: {input_folder}; '
|
133 |
+
f'{gt_key}: {gt_folder}')
|
134 |
+
# ensure that the two meta_info files are the same
|
135 |
+
with open(osp.join(input_folder, 'meta_info.txt')) as fin:
|
136 |
+
input_lmdb_keys = [line.split('.')[0] for line in fin]
|
137 |
+
with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
|
138 |
+
gt_lmdb_keys = [line.split('.')[0] for line in fin]
|
139 |
+
if set(input_lmdb_keys) != set(gt_lmdb_keys):
|
140 |
+
raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
|
141 |
+
else:
|
142 |
+
paths = []
|
143 |
+
for lmdb_key in sorted(input_lmdb_keys):
|
144 |
+
paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
|
145 |
+
return paths
|
146 |
+
|
147 |
+
|
148 |
+
def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
|
149 |
+
"""Generate paired paths from an meta information file.
|
150 |
+
|
151 |
+
Each line in the meta information file contains the image names and
|
152 |
+
image shape (usually for gt), separated by a white space.
|
153 |
+
|
154 |
+
Example of an meta information file:
|
155 |
+
```
|
156 |
+
0001_s001.png (480,480,3)
|
157 |
+
0001_s002.png (480,480,3)
|
158 |
+
```
|
159 |
+
|
160 |
+
Args:
|
161 |
+
folders (list[str]): A list of folder path. The order of list should
|
162 |
+
be [input_folder, gt_folder].
|
163 |
+
keys (list[str]): A list of keys identifying folders. The order should
|
164 |
+
be in consistent with folders, e.g., ['lq', 'gt'].
|
165 |
+
meta_info_file (str): Path to the meta information file.
|
166 |
+
filename_tmpl (str): Template for each filename. Note that the
|
167 |
+
template excludes the file extension. Usually the filename_tmpl is
|
168 |
+
for files in the input folder.
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
list[str]: Returned path list.
|
172 |
+
"""
|
173 |
+
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
|
174 |
+
f'But got {len(folders)}')
|
175 |
+
assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
|
176 |
+
input_folder, gt_folder = folders
|
177 |
+
input_key, gt_key = keys
|
178 |
+
|
179 |
+
with open(meta_info_file, 'r') as fin:
|
180 |
+
gt_names = [line.split(' ')[0] for line in fin]
|
181 |
+
|
182 |
+
paths = []
|
183 |
+
for gt_name in gt_names:
|
184 |
+
basename, ext = osp.splitext(osp.basename(gt_name))
|
185 |
+
input_name = f'{filename_tmpl.format(basename)}{ext}'
|
186 |
+
input_path = osp.join(input_folder, input_name)
|
187 |
+
gt_path = osp.join(gt_folder, gt_name)
|
188 |
+
paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
|
189 |
+
return paths
|
190 |
+
|
191 |
+
|
192 |
+
def paired_paths_from_folder(folders, keys, filename_tmpl):
|
193 |
+
"""Generate paired paths from folders.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
folders (list[str]): A list of folder path. The order of list should
|
197 |
+
be [input_folder, gt_folder].
|
198 |
+
keys (list[str]): A list of keys identifying folders. The order should
|
199 |
+
be in consistent with folders, e.g., ['lq', 'gt'].
|
200 |
+
filename_tmpl (str): Template for each filename. Note that the
|
201 |
+
template excludes the file extension. Usually the filename_tmpl is
|
202 |
+
for files in the input folder.
|
203 |
+
|
204 |
+
Returns:
|
205 |
+
list[str]: Returned path list.
|
206 |
+
"""
|
207 |
+
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
|
208 |
+
f'But got {len(folders)}')
|
209 |
+
assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
|
210 |
+
input_folder, gt_folder = folders
|
211 |
+
input_key, gt_key = keys
|
212 |
+
|
213 |
+
input_paths = list(scandir(input_folder))
|
214 |
+
gt_paths = list(scandir(gt_folder))
|
215 |
+
assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
|
216 |
+
f'{len(input_paths)}, {len(gt_paths)}.')
|
217 |
+
paths = []
|
218 |
+
for gt_path in gt_paths:
|
219 |
+
basename, ext = osp.splitext(osp.basename(gt_path))
|
220 |
+
input_name = f'{filename_tmpl.format(basename)}{ext}'
|
221 |
+
input_path = osp.join(input_folder, input_name)
|
222 |
+
assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.')
|
223 |
+
gt_path = osp.join(gt_folder, gt_path)
|
224 |
+
paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
|
225 |
+
return paths
|
226 |
+
|
227 |
+
|
228 |
+
def paths_from_folder(folder):
|
229 |
+
"""Generate paths from folder.
|
230 |
+
|
231 |
+
Args:
|
232 |
+
folder (str): Folder path.
|
233 |
+
|
234 |
+
Returns:
|
235 |
+
list[str]: Returned path list.
|
236 |
+
"""
|
237 |
+
|
238 |
+
paths = list(scandir(folder))
|
239 |
+
paths = [osp.join(folder, path) for path in paths]
|
240 |
+
return paths
|
241 |
+
|
242 |
+
|
243 |
+
def paths_from_lmdb(folder):
|
244 |
+
"""Generate paths from lmdb.
|
245 |
+
|
246 |
+
Args:
|
247 |
+
folder (str): Folder path.
|
248 |
+
|
249 |
+
Returns:
|
250 |
+
list[str]: Returned path list.
|
251 |
+
"""
|
252 |
+
if not folder.endswith('.lmdb'):
|
253 |
+
raise ValueError(f'Folder {folder}folder should in lmdb format.')
|
254 |
+
with open(osp.join(folder, 'meta_info.txt')) as fin:
|
255 |
+
paths = [line.split('.')[0] for line in fin]
|
256 |
+
return paths
|
257 |
+
|
258 |
+
|
259 |
+
def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
|
260 |
+
"""Generate Gaussian kernel used in `duf_downsample`.
|
261 |
+
|
262 |
+
Args:
|
263 |
+
kernel_size (int): Kernel size. Default: 13.
|
264 |
+
sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
|
265 |
+
|
266 |
+
Returns:
|
267 |
+
np.array: The Gaussian kernel.
|
268 |
+
"""
|
269 |
+
from scipy.ndimage import filters as filters
|
270 |
+
kernel = np.zeros((kernel_size, kernel_size))
|
271 |
+
# set element at the middle to one, a dirac delta
|
272 |
+
kernel[kernel_size // 2, kernel_size // 2] = 1
|
273 |
+
# gaussian-smooth the dirac, resulting in a gaussian filter
|
274 |
+
return filters.gaussian_filter(kernel, sigma)
|
275 |
+
|
276 |
+
|
277 |
+
def duf_downsample(x, kernel_size=13, scale=4):
|
278 |
+
"""Downsamping with Gaussian kernel used in the DUF official code.
|
279 |
+
|
280 |
+
Args:
|
281 |
+
x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
|
282 |
+
kernel_size (int): Kernel size. Default: 13.
|
283 |
+
scale (int): Downsampling factor. Supported scale: (2, 3, 4).
|
284 |
+
Default: 4.
|
285 |
+
|
286 |
+
Returns:
|
287 |
+
Tensor: DUF downsampled frames.
|
288 |
+
"""
|
289 |
+
assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
|
290 |
+
|
291 |
+
squeeze_flag = False
|
292 |
+
if x.ndim == 4:
|
293 |
+
squeeze_flag = True
|
294 |
+
x = x.unsqueeze(0)
|
295 |
+
b, t, c, h, w = x.size()
|
296 |
+
x = x.view(-1, 1, h, w)
|
297 |
+
pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
|
298 |
+
x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
|
299 |
+
|
300 |
+
gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
|
301 |
+
gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
|
302 |
+
x = F.conv2d(x, gaussian_filter, stride=scale)
|
303 |
+
x = x[:, :, 2:-2, 2:-2]
|
304 |
+
x = x.view(b, t, c, x.size(2), x.size(3))
|
305 |
+
if squeeze_flag:
|
306 |
+
x = x.squeeze(0)
|
307 |
+
return x
|
308 |
+
|
309 |
+
|
310 |
+
def brush_stroke_mask(img, color=(255,255,255)):
|
311 |
+
min_num_vertex = 8
|
312 |
+
max_num_vertex = 28
|
313 |
+
mean_angle = 2*math.pi / 5
|
314 |
+
angle_range = 2*math.pi / 12
|
315 |
+
# training large mask ratio (training setting)
|
316 |
+
min_width = 30
|
317 |
+
max_width = 70
|
318 |
+
# very large mask ratio (test setting and refine after 200k)
|
319 |
+
# min_width = 80
|
320 |
+
# max_width = 120
|
321 |
+
def generate_mask(H, W, img=None):
|
322 |
+
average_radius = math.sqrt(H*H+W*W) / 8
|
323 |
+
mask = Image.new('RGB', (W, H), 0)
|
324 |
+
if img is not None: mask = img # Image.fromarray(img)
|
325 |
+
|
326 |
+
for _ in range(np.random.randint(1, 4)):
|
327 |
+
num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
|
328 |
+
angle_min = mean_angle - np.random.uniform(0, angle_range)
|
329 |
+
angle_max = mean_angle + np.random.uniform(0, angle_range)
|
330 |
+
angles = []
|
331 |
+
vertex = []
|
332 |
+
for i in range(num_vertex):
|
333 |
+
if i % 2 == 0:
|
334 |
+
angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
|
335 |
+
else:
|
336 |
+
angles.append(np.random.uniform(angle_min, angle_max))
|
337 |
+
|
338 |
+
h, w = mask.size
|
339 |
+
vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
|
340 |
+
for i in range(num_vertex):
|
341 |
+
r = np.clip(
|
342 |
+
np.random.normal(loc=average_radius, scale=average_radius//2),
|
343 |
+
0, 2*average_radius)
|
344 |
+
new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
|
345 |
+
new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
|
346 |
+
vertex.append((int(new_x), int(new_y)))
|
347 |
+
|
348 |
+
draw = ImageDraw.Draw(mask)
|
349 |
+
width = int(np.random.uniform(min_width, max_width))
|
350 |
+
draw.line(vertex, fill=color, width=width)
|
351 |
+
for v in vertex:
|
352 |
+
draw.ellipse((v[0] - width//2,
|
353 |
+
v[1] - width//2,
|
354 |
+
v[0] + width//2,
|
355 |
+
v[1] + width//2),
|
356 |
+
fill=color)
|
357 |
+
|
358 |
+
return mask
|
359 |
+
|
360 |
+
width, height = img.size
|
361 |
+
mask = generate_mask(height, width, img)
|
362 |
+
return mask
|
363 |
+
|
364 |
+
|
365 |
+
def random_ff_mask(shape, max_angle = 10, max_len = 100, max_width = 70, times = 10):
|
366 |
+
"""Generate a random free form mask with configuration.
|
367 |
+
Args:
|
368 |
+
config: Config should have configuration including IMG_SHAPES,
|
369 |
+
VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH.
|
370 |
+
Returns:
|
371 |
+
tuple: (top, left, height, width)
|
372 |
+
Link:
|
373 |
+
https://github.com/csqiangwen/DeepFillv2_Pytorch/blob/master/train_dataset.py
|
374 |
+
"""
|
375 |
+
height = shape[0]
|
376 |
+
width = shape[1]
|
377 |
+
mask = np.zeros((height, width), np.float32)
|
378 |
+
times = np.random.randint(times-5, times)
|
379 |
+
for i in range(times):
|
380 |
+
start_x = np.random.randint(width)
|
381 |
+
start_y = np.random.randint(height)
|
382 |
+
for j in range(1 + np.random.randint(5)):
|
383 |
+
angle = 0.01 + np.random.randint(max_angle)
|
384 |
+
if i % 2 == 0:
|
385 |
+
angle = 2 * 3.1415926 - angle
|
386 |
+
length = 10 + np.random.randint(max_len-20, max_len)
|
387 |
+
brush_w = 5 + np.random.randint(max_width-30, max_width)
|
388 |
+
end_x = (start_x + length * np.sin(angle)).astype(np.int32)
|
389 |
+
end_y = (start_y + length * np.cos(angle)).astype(np.int32)
|
390 |
+
cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)
|
391 |
+
start_x, start_y = end_x, end_y
|
392 |
+
return mask.astype(np.float32)
|
blissful_tuner/codeformer/basicsr/data/ffhq_blind_dataset.py
ADDED
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import math
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
import os.path as osp
|
6 |
+
from scipy.io import loadmat
|
7 |
+
from PIL import Image
|
8 |
+
import torch
|
9 |
+
import torch.utils.data as data
|
10 |
+
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
|
11 |
+
adjust_hue, adjust_saturation, normalize)
|
12 |
+
from codeformer.basicsr.data import gaussian_kernels as gaussian_kernels
|
13 |
+
from codeformer.basicsr.data.transforms import augment
|
14 |
+
from codeformer.basicsr.data.data_util import paths_from_folder, brush_stroke_mask, random_ff_mask
|
15 |
+
from codeformer.basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
16 |
+
from codeformer.basicsr.utils.registry import DATASET_REGISTRY
|
17 |
+
|
18 |
+
@DATASET_REGISTRY.register()
|
19 |
+
class FFHQBlindDataset(data.Dataset):
|
20 |
+
|
21 |
+
def __init__(self, opt):
|
22 |
+
super(FFHQBlindDataset, self).__init__()
|
23 |
+
logger = get_root_logger()
|
24 |
+
self.opt = opt
|
25 |
+
# file client (io backend)
|
26 |
+
self.file_client = None
|
27 |
+
self.io_backend_opt = opt['io_backend']
|
28 |
+
|
29 |
+
self.gt_folder = opt['dataroot_gt']
|
30 |
+
self.gt_size = opt.get('gt_size', 512)
|
31 |
+
self.in_size = opt.get('in_size', 512)
|
32 |
+
assert self.gt_size >= self.in_size, 'Wrong setting.'
|
33 |
+
|
34 |
+
self.mean = opt.get('mean', [0.5, 0.5, 0.5])
|
35 |
+
self.std = opt.get('std', [0.5, 0.5, 0.5])
|
36 |
+
|
37 |
+
self.component_path = opt.get('component_path', None)
|
38 |
+
self.latent_gt_path = opt.get('latent_gt_path', None)
|
39 |
+
|
40 |
+
if self.component_path is not None:
|
41 |
+
self.crop_components = True
|
42 |
+
self.components_dict = torch.load(self.component_path)
|
43 |
+
self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4)
|
44 |
+
self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1)
|
45 |
+
self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3)
|
46 |
+
else:
|
47 |
+
self.crop_components = False
|
48 |
+
|
49 |
+
if self.latent_gt_path is not None:
|
50 |
+
self.load_latent_gt = True
|
51 |
+
self.latent_gt_dict = torch.load(self.latent_gt_path)
|
52 |
+
else:
|
53 |
+
self.load_latent_gt = False
|
54 |
+
|
55 |
+
if self.io_backend_opt['type'] == 'lmdb':
|
56 |
+
self.io_backend_opt['db_paths'] = self.gt_folder
|
57 |
+
if not self.gt_folder.endswith('.lmdb'):
|
58 |
+
raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}')
|
59 |
+
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
|
60 |
+
self.paths = [line.split('.')[0] for line in fin]
|
61 |
+
else:
|
62 |
+
self.paths = paths_from_folder(self.gt_folder)
|
63 |
+
|
64 |
+
# inpainting mask
|
65 |
+
self.gen_inpaint_mask = opt.get('gen_inpaint_mask', False)
|
66 |
+
if self.gen_inpaint_mask:
|
67 |
+
logger.info(f'generate mask ...')
|
68 |
+
# self.mask_max_angle = opt.get('mask_max_angle', 10)
|
69 |
+
# self.mask_max_len = opt.get('mask_max_len', 150)
|
70 |
+
# self.mask_max_width = opt.get('mask_max_width', 50)
|
71 |
+
# self.mask_draw_times = opt.get('mask_draw_times', 10)
|
72 |
+
# # print
|
73 |
+
# logger.info(f'mask_max_angle: {self.mask_max_angle}')
|
74 |
+
# logger.info(f'mask_max_len: {self.mask_max_len}')
|
75 |
+
# logger.info(f'mask_max_width: {self.mask_max_width}')
|
76 |
+
# logger.info(f'mask_draw_times: {self.mask_draw_times}')
|
77 |
+
|
78 |
+
# perform corrupt
|
79 |
+
self.use_corrupt = opt.get('use_corrupt', True)
|
80 |
+
self.use_motion_kernel = False
|
81 |
+
# self.use_motion_kernel = opt.get('use_motion_kernel', True)
|
82 |
+
|
83 |
+
if self.use_motion_kernel:
|
84 |
+
self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
|
85 |
+
motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth')
|
86 |
+
self.motion_kernels = torch.load(motion_kernel_path)
|
87 |
+
|
88 |
+
if self.use_corrupt and not self.gen_inpaint_mask:
|
89 |
+
# degradation configurations
|
90 |
+
self.blur_kernel_size = opt['blur_kernel_size']
|
91 |
+
self.blur_sigma = opt['blur_sigma']
|
92 |
+
self.kernel_list = opt['kernel_list']
|
93 |
+
self.kernel_prob = opt['kernel_prob']
|
94 |
+
self.downsample_range = opt['downsample_range']
|
95 |
+
self.noise_range = opt['noise_range']
|
96 |
+
self.jpeg_range = opt['jpeg_range']
|
97 |
+
# print
|
98 |
+
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
|
99 |
+
logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
|
100 |
+
logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
|
101 |
+
logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
|
102 |
+
|
103 |
+
# color jitter
|
104 |
+
self.color_jitter_prob = opt.get('color_jitter_prob', None)
|
105 |
+
self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None)
|
106 |
+
self.color_jitter_shift = opt.get('color_jitter_shift', 20)
|
107 |
+
if self.color_jitter_prob is not None:
|
108 |
+
logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
|
109 |
+
|
110 |
+
# to gray
|
111 |
+
self.gray_prob = opt.get('gray_prob', 0.0)
|
112 |
+
if self.gray_prob is not None:
|
113 |
+
logger.info(f'Use random gray. Prob: {self.gray_prob}')
|
114 |
+
self.color_jitter_shift /= 255.
|
115 |
+
|
116 |
+
@staticmethod
|
117 |
+
def color_jitter(img, shift):
|
118 |
+
"""jitter color: randomly jitter the RGB values, in numpy formats"""
|
119 |
+
jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
|
120 |
+
img = img + jitter_val
|
121 |
+
img = np.clip(img, 0, 1)
|
122 |
+
return img
|
123 |
+
|
124 |
+
@staticmethod
|
125 |
+
def color_jitter_pt(img, brightness, contrast, saturation, hue):
|
126 |
+
"""jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
|
127 |
+
fn_idx = torch.randperm(4)
|
128 |
+
for fn_id in fn_idx:
|
129 |
+
if fn_id == 0 and brightness is not None:
|
130 |
+
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
|
131 |
+
img = adjust_brightness(img, brightness_factor)
|
132 |
+
|
133 |
+
if fn_id == 1 and contrast is not None:
|
134 |
+
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
|
135 |
+
img = adjust_contrast(img, contrast_factor)
|
136 |
+
|
137 |
+
if fn_id == 2 and saturation is not None:
|
138 |
+
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
|
139 |
+
img = adjust_saturation(img, saturation_factor)
|
140 |
+
|
141 |
+
if fn_id == 3 and hue is not None:
|
142 |
+
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
|
143 |
+
img = adjust_hue(img, hue_factor)
|
144 |
+
return img
|
145 |
+
|
146 |
+
|
147 |
+
def get_component_locations(self, name, status):
|
148 |
+
components_bbox = self.components_dict[name]
|
149 |
+
if status[0]: # hflip
|
150 |
+
# exchange right and left eye
|
151 |
+
tmp = components_bbox['left_eye']
|
152 |
+
components_bbox['left_eye'] = components_bbox['right_eye']
|
153 |
+
components_bbox['right_eye'] = tmp
|
154 |
+
# modify the width coordinate
|
155 |
+
components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0]
|
156 |
+
components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0]
|
157 |
+
components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0]
|
158 |
+
components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0]
|
159 |
+
|
160 |
+
locations_gt = {}
|
161 |
+
locations_in = {}
|
162 |
+
for part in ['left_eye', 'right_eye', 'nose', 'mouth']:
|
163 |
+
mean = components_bbox[part][0:2]
|
164 |
+
half_len = components_bbox[part][2]
|
165 |
+
if 'eye' in part:
|
166 |
+
half_len *= self.eye_enlarge_ratio
|
167 |
+
elif part == 'nose':
|
168 |
+
half_len *= self.nose_enlarge_ratio
|
169 |
+
elif part == 'mouth':
|
170 |
+
half_len *= self.mouth_enlarge_ratio
|
171 |
+
loc = np.hstack((mean - half_len + 1, mean + half_len))
|
172 |
+
loc = torch.from_numpy(loc).float()
|
173 |
+
locations_gt[part] = loc
|
174 |
+
loc_in = loc/(self.gt_size//self.in_size)
|
175 |
+
locations_in[part] = loc_in
|
176 |
+
return locations_gt, locations_in
|
177 |
+
|
178 |
+
|
179 |
+
def __getitem__(self, index):
|
180 |
+
if self.file_client is None:
|
181 |
+
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
|
182 |
+
|
183 |
+
# load gt image
|
184 |
+
gt_path = self.paths[index]
|
185 |
+
name = osp.basename(gt_path)[:-4]
|
186 |
+
img_bytes = self.file_client.get(gt_path)
|
187 |
+
img_gt = imfrombytes(img_bytes, float32=True)
|
188 |
+
|
189 |
+
# random horizontal flip
|
190 |
+
img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
|
191 |
+
|
192 |
+
if self.load_latent_gt:
|
193 |
+
if status[0]:
|
194 |
+
latent_gt = self.latent_gt_dict['hflip'][name]
|
195 |
+
else:
|
196 |
+
latent_gt = self.latent_gt_dict['orig'][name]
|
197 |
+
|
198 |
+
if self.crop_components:
|
199 |
+
locations_gt, locations_in = self.get_component_locations(name, status)
|
200 |
+
|
201 |
+
# generate in image
|
202 |
+
img_in = img_gt
|
203 |
+
if self.use_corrupt and not self.gen_inpaint_mask:
|
204 |
+
# motion blur
|
205 |
+
if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
|
206 |
+
m_i = random.randint(0,31)
|
207 |
+
k = self.motion_kernels[f'{m_i:02d}']
|
208 |
+
img_in = cv2.filter2D(img_in,-1,k)
|
209 |
+
|
210 |
+
# gaussian blur
|
211 |
+
kernel = gaussian_kernels.random_mixed_kernels(
|
212 |
+
self.kernel_list,
|
213 |
+
self.kernel_prob,
|
214 |
+
self.blur_kernel_size,
|
215 |
+
self.blur_sigma,
|
216 |
+
self.blur_sigma,
|
217 |
+
[-math.pi, math.pi],
|
218 |
+
noise_range=None)
|
219 |
+
img_in = cv2.filter2D(img_in, -1, kernel)
|
220 |
+
|
221 |
+
# downsample
|
222 |
+
scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
|
223 |
+
img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
|
224 |
+
|
225 |
+
# noise
|
226 |
+
if self.noise_range is not None:
|
227 |
+
noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.)
|
228 |
+
noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma
|
229 |
+
img_in = img_in + noise
|
230 |
+
img_in = np.clip(img_in, 0, 1)
|
231 |
+
|
232 |
+
# jpeg
|
233 |
+
if self.jpeg_range is not None:
|
234 |
+
jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1])
|
235 |
+
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
|
236 |
+
_, encimg = cv2.imencode('.jpg', img_in * 255., encode_param)
|
237 |
+
img_in = np.float32(cv2.imdecode(encimg, 1)) / 255.
|
238 |
+
|
239 |
+
# resize to in_size
|
240 |
+
img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
|
241 |
+
|
242 |
+
# if self.gen_inpaint_mask:
|
243 |
+
# inpaint_mask = random_ff_mask(shape=(self.gt_size,self.gt_size),
|
244 |
+
# max_angle = self.mask_max_angle, max_len = self.mask_max_len,
|
245 |
+
# max_width = self.mask_max_width, times = self.mask_draw_times)
|
246 |
+
# img_in = img_in * (1 - inpaint_mask.reshape(self.gt_size,self.gt_size,1)) + \
|
247 |
+
# 1.0 * inpaint_mask.reshape(self.gt_size,self.gt_size,1)
|
248 |
+
|
249 |
+
# inpaint_mask = torch.from_numpy(inpaint_mask).view(1,self.gt_size,self.gt_size)
|
250 |
+
|
251 |
+
if self.gen_inpaint_mask:
|
252 |
+
img_in = (img_in*255).astype('uint8')
|
253 |
+
img_in = brush_stroke_mask(Image.fromarray(img_in))
|
254 |
+
img_in = np.array(img_in) / 255.
|
255 |
+
|
256 |
+
# random color jitter (only for lq)
|
257 |
+
if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
|
258 |
+
img_in = self.color_jitter(img_in, self.color_jitter_shift)
|
259 |
+
# random to gray (only for lq)
|
260 |
+
if self.gray_prob and np.random.uniform() < self.gray_prob:
|
261 |
+
img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
|
262 |
+
img_in = np.tile(img_in[:, :, None], [1, 1, 3])
|
263 |
+
|
264 |
+
# BGR to RGB, HWC to CHW, numpy to tensor
|
265 |
+
img_in, img_gt = img2tensor([img_in, img_gt], bgr2rgb=True, float32=True)
|
266 |
+
|
267 |
+
# random color jitter (pytorch version) (only for lq)
|
268 |
+
if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
|
269 |
+
brightness = self.opt.get('brightness', (0.5, 1.5))
|
270 |
+
contrast = self.opt.get('contrast', (0.5, 1.5))
|
271 |
+
saturation = self.opt.get('saturation', (0, 1.5))
|
272 |
+
hue = self.opt.get('hue', (-0.1, 0.1))
|
273 |
+
img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)
|
274 |
+
|
275 |
+
# round and clip
|
276 |
+
img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255.
|
277 |
+
|
278 |
+
# Set vgg range_norm=True if use the normalization here
|
279 |
+
# normalize
|
280 |
+
normalize(img_in, self.mean, self.std, inplace=True)
|
281 |
+
normalize(img_gt, self.mean, self.std, inplace=True)
|
282 |
+
|
283 |
+
return_dict = {'in': img_in, 'gt': img_gt, 'gt_path': gt_path}
|
284 |
+
|
285 |
+
if self.crop_components:
|
286 |
+
return_dict['locations_in'] = locations_in
|
287 |
+
return_dict['locations_gt'] = locations_gt
|
288 |
+
|
289 |
+
if self.load_latent_gt:
|
290 |
+
return_dict['latent_gt'] = latent_gt
|
291 |
+
|
292 |
+
# if self.gen_inpaint_mask:
|
293 |
+
# return_dict['inpaint_mask'] = inpaint_mask
|
294 |
+
|
295 |
+
return return_dict
|
296 |
+
|
297 |
+
|
298 |
+
def __len__(self):
|
299 |
+
return len(self.paths)
|
blissful_tuner/codeformer/basicsr/data/ffhq_blind_joint_dataset.py
ADDED
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import math
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
import os.path as osp
|
6 |
+
from scipy.io import loadmat
|
7 |
+
import torch
|
8 |
+
import torch.utils.data as data
|
9 |
+
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
|
10 |
+
adjust_hue, adjust_saturation, normalize)
|
11 |
+
from codeformer.basicsr.data import gaussian_kernels as gaussian_kernels
|
12 |
+
from codeformer.basicsr.data.transforms import augment
|
13 |
+
from codeformer.basicsr.data.data_util import paths_from_folder
|
14 |
+
from codeformer.basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
15 |
+
from codeformer.basicsr.utils.registry import DATASET_REGISTRY
|
16 |
+
|
17 |
+
@DATASET_REGISTRY.register()
|
18 |
+
class FFHQBlindJointDataset(data.Dataset):
|
19 |
+
|
20 |
+
def __init__(self, opt):
|
21 |
+
super(FFHQBlindJointDataset, self).__init__()
|
22 |
+
logger = get_root_logger()
|
23 |
+
self.opt = opt
|
24 |
+
# file client (io backend)
|
25 |
+
self.file_client = None
|
26 |
+
self.io_backend_opt = opt['io_backend']
|
27 |
+
|
28 |
+
self.gt_folder = opt['dataroot_gt']
|
29 |
+
self.gt_size = opt.get('gt_size', 512)
|
30 |
+
self.in_size = opt.get('in_size', 512)
|
31 |
+
assert self.gt_size >= self.in_size, 'Wrong setting.'
|
32 |
+
|
33 |
+
self.mean = opt.get('mean', [0.5, 0.5, 0.5])
|
34 |
+
self.std = opt.get('std', [0.5, 0.5, 0.5])
|
35 |
+
|
36 |
+
self.component_path = opt.get('component_path', None)
|
37 |
+
self.latent_gt_path = opt.get('latent_gt_path', None)
|
38 |
+
|
39 |
+
if self.component_path is not None:
|
40 |
+
self.crop_components = True
|
41 |
+
self.components_dict = torch.load(self.component_path)
|
42 |
+
self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4)
|
43 |
+
self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1)
|
44 |
+
self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3)
|
45 |
+
else:
|
46 |
+
self.crop_components = False
|
47 |
+
|
48 |
+
if self.latent_gt_path is not None:
|
49 |
+
self.load_latent_gt = True
|
50 |
+
self.latent_gt_dict = torch.load(self.latent_gt_path)
|
51 |
+
else:
|
52 |
+
self.load_latent_gt = False
|
53 |
+
|
54 |
+
if self.io_backend_opt['type'] == 'lmdb':
|
55 |
+
self.io_backend_opt['db_paths'] = self.gt_folder
|
56 |
+
if not self.gt_folder.endswith('.lmdb'):
|
57 |
+
raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}')
|
58 |
+
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
|
59 |
+
self.paths = [line.split('.')[0] for line in fin]
|
60 |
+
else:
|
61 |
+
self.paths = paths_from_folder(self.gt_folder)
|
62 |
+
|
63 |
+
# perform corrupt
|
64 |
+
self.use_corrupt = opt.get('use_corrupt', True)
|
65 |
+
self.use_motion_kernel = False
|
66 |
+
# self.use_motion_kernel = opt.get('use_motion_kernel', True)
|
67 |
+
|
68 |
+
if self.use_motion_kernel:
|
69 |
+
self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
|
70 |
+
motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth')
|
71 |
+
self.motion_kernels = torch.load(motion_kernel_path)
|
72 |
+
|
73 |
+
if self.use_corrupt:
|
74 |
+
# degradation configurations
|
75 |
+
self.blur_kernel_size = self.opt['blur_kernel_size']
|
76 |
+
self.kernel_list = self.opt['kernel_list']
|
77 |
+
self.kernel_prob = self.opt['kernel_prob']
|
78 |
+
# Small degradation
|
79 |
+
self.blur_sigma = self.opt['blur_sigma']
|
80 |
+
self.downsample_range = self.opt['downsample_range']
|
81 |
+
self.noise_range = self.opt['noise_range']
|
82 |
+
self.jpeg_range = self.opt['jpeg_range']
|
83 |
+
# Large degradation
|
84 |
+
self.blur_sigma_large = self.opt['blur_sigma_large']
|
85 |
+
self.downsample_range_large = self.opt['downsample_range_large']
|
86 |
+
self.noise_range_large = self.opt['noise_range_large']
|
87 |
+
self.jpeg_range_large = self.opt['jpeg_range_large']
|
88 |
+
|
89 |
+
# print
|
90 |
+
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
|
91 |
+
logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
|
92 |
+
logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
|
93 |
+
logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
|
94 |
+
|
95 |
+
# color jitter
|
96 |
+
self.color_jitter_prob = opt.get('color_jitter_prob', None)
|
97 |
+
self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None)
|
98 |
+
self.color_jitter_shift = opt.get('color_jitter_shift', 20)
|
99 |
+
if self.color_jitter_prob is not None:
|
100 |
+
logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
|
101 |
+
|
102 |
+
# to gray
|
103 |
+
self.gray_prob = opt.get('gray_prob', 0.0)
|
104 |
+
if self.gray_prob is not None:
|
105 |
+
logger.info(f'Use random gray. Prob: {self.gray_prob}')
|
106 |
+
self.color_jitter_shift /= 255.
|
107 |
+
|
108 |
+
@staticmethod
|
109 |
+
def color_jitter(img, shift):
|
110 |
+
"""jitter color: randomly jitter the RGB values, in numpy formats"""
|
111 |
+
jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
|
112 |
+
img = img + jitter_val
|
113 |
+
img = np.clip(img, 0, 1)
|
114 |
+
return img
|
115 |
+
|
116 |
+
@staticmethod
|
117 |
+
def color_jitter_pt(img, brightness, contrast, saturation, hue):
|
118 |
+
"""jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
|
119 |
+
fn_idx = torch.randperm(4)
|
120 |
+
for fn_id in fn_idx:
|
121 |
+
if fn_id == 0 and brightness is not None:
|
122 |
+
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
|
123 |
+
img = adjust_brightness(img, brightness_factor)
|
124 |
+
|
125 |
+
if fn_id == 1 and contrast is not None:
|
126 |
+
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
|
127 |
+
img = adjust_contrast(img, contrast_factor)
|
128 |
+
|
129 |
+
if fn_id == 2 and saturation is not None:
|
130 |
+
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
|
131 |
+
img = adjust_saturation(img, saturation_factor)
|
132 |
+
|
133 |
+
if fn_id == 3 and hue is not None:
|
134 |
+
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
|
135 |
+
img = adjust_hue(img, hue_factor)
|
136 |
+
return img
|
137 |
+
|
138 |
+
|
139 |
+
def get_component_locations(self, name, status):
|
140 |
+
components_bbox = self.components_dict[name]
|
141 |
+
if status[0]: # hflip
|
142 |
+
# exchange right and left eye
|
143 |
+
tmp = components_bbox['left_eye']
|
144 |
+
components_bbox['left_eye'] = components_bbox['right_eye']
|
145 |
+
components_bbox['right_eye'] = tmp
|
146 |
+
# modify the width coordinate
|
147 |
+
components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0]
|
148 |
+
components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0]
|
149 |
+
components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0]
|
150 |
+
components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0]
|
151 |
+
|
152 |
+
locations_gt = {}
|
153 |
+
locations_in = {}
|
154 |
+
for part in ['left_eye', 'right_eye', 'nose', 'mouth']:
|
155 |
+
mean = components_bbox[part][0:2]
|
156 |
+
half_len = components_bbox[part][2]
|
157 |
+
if 'eye' in part:
|
158 |
+
half_len *= self.eye_enlarge_ratio
|
159 |
+
elif part == 'nose':
|
160 |
+
half_len *= self.nose_enlarge_ratio
|
161 |
+
elif part == 'mouth':
|
162 |
+
half_len *= self.mouth_enlarge_ratio
|
163 |
+
loc = np.hstack((mean - half_len + 1, mean + half_len))
|
164 |
+
loc = torch.from_numpy(loc).float()
|
165 |
+
locations_gt[part] = loc
|
166 |
+
loc_in = loc/(self.gt_size//self.in_size)
|
167 |
+
locations_in[part] = loc_in
|
168 |
+
return locations_gt, locations_in
|
169 |
+
|
170 |
+
|
171 |
+
def __getitem__(self, index):
|
172 |
+
if self.file_client is None:
|
173 |
+
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
|
174 |
+
|
175 |
+
# load gt image
|
176 |
+
gt_path = self.paths[index]
|
177 |
+
name = osp.basename(gt_path)[:-4]
|
178 |
+
img_bytes = self.file_client.get(gt_path)
|
179 |
+
img_gt = imfrombytes(img_bytes, float32=True)
|
180 |
+
|
181 |
+
# random horizontal flip
|
182 |
+
img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
|
183 |
+
|
184 |
+
if self.load_latent_gt:
|
185 |
+
if status[0]:
|
186 |
+
latent_gt = self.latent_gt_dict['hflip'][name]
|
187 |
+
else:
|
188 |
+
latent_gt = self.latent_gt_dict['orig'][name]
|
189 |
+
|
190 |
+
if self.crop_components:
|
191 |
+
locations_gt, locations_in = self.get_component_locations(name, status)
|
192 |
+
|
193 |
+
# generate in image
|
194 |
+
img_in = img_gt
|
195 |
+
if self.use_corrupt:
|
196 |
+
# motion blur
|
197 |
+
if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
|
198 |
+
m_i = random.randint(0,31)
|
199 |
+
k = self.motion_kernels[f'{m_i:02d}']
|
200 |
+
img_in = cv2.filter2D(img_in,-1,k)
|
201 |
+
|
202 |
+
# gaussian blur
|
203 |
+
kernel = gaussian_kernels.random_mixed_kernels(
|
204 |
+
self.kernel_list,
|
205 |
+
self.kernel_prob,
|
206 |
+
self.blur_kernel_size,
|
207 |
+
self.blur_sigma,
|
208 |
+
self.blur_sigma,
|
209 |
+
[-math.pi, math.pi],
|
210 |
+
noise_range=None)
|
211 |
+
img_in = cv2.filter2D(img_in, -1, kernel)
|
212 |
+
|
213 |
+
# downsample
|
214 |
+
scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
|
215 |
+
img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
|
216 |
+
|
217 |
+
# noise
|
218 |
+
if self.noise_range is not None:
|
219 |
+
noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.)
|
220 |
+
noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma
|
221 |
+
img_in = img_in + noise
|
222 |
+
img_in = np.clip(img_in, 0, 1)
|
223 |
+
|
224 |
+
# jpeg
|
225 |
+
if self.jpeg_range is not None:
|
226 |
+
jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1])
|
227 |
+
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
|
228 |
+
_, encimg = cv2.imencode('.jpg', img_in * 255., encode_param)
|
229 |
+
img_in = np.float32(cv2.imdecode(encimg, 1)) / 255.
|
230 |
+
|
231 |
+
# resize to in_size
|
232 |
+
img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
|
233 |
+
|
234 |
+
|
235 |
+
# generate in_large with large degradation
|
236 |
+
img_in_large = img_gt
|
237 |
+
|
238 |
+
if self.use_corrupt:
|
239 |
+
# motion blur
|
240 |
+
if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
|
241 |
+
m_i = random.randint(0,31)
|
242 |
+
k = self.motion_kernels[f'{m_i:02d}']
|
243 |
+
img_in_large = cv2.filter2D(img_in_large,-1,k)
|
244 |
+
|
245 |
+
# gaussian blur
|
246 |
+
kernel = gaussian_kernels.random_mixed_kernels(
|
247 |
+
self.kernel_list,
|
248 |
+
self.kernel_prob,
|
249 |
+
self.blur_kernel_size,
|
250 |
+
self.blur_sigma_large,
|
251 |
+
self.blur_sigma_large,
|
252 |
+
[-math.pi, math.pi],
|
253 |
+
noise_range=None)
|
254 |
+
img_in_large = cv2.filter2D(img_in_large, -1, kernel)
|
255 |
+
|
256 |
+
# downsample
|
257 |
+
scale = np.random.uniform(self.downsample_range_large[0], self.downsample_range_large[1])
|
258 |
+
img_in_large = cv2.resize(img_in_large, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
|
259 |
+
|
260 |
+
# noise
|
261 |
+
if self.noise_range_large is not None:
|
262 |
+
noise_sigma = np.random.uniform(self.noise_range_large[0] / 255., self.noise_range_large[1] / 255.)
|
263 |
+
noise = np.float32(np.random.randn(*(img_in_large.shape))) * noise_sigma
|
264 |
+
img_in_large = img_in_large + noise
|
265 |
+
img_in_large = np.clip(img_in_large, 0, 1)
|
266 |
+
|
267 |
+
# jpeg
|
268 |
+
if self.jpeg_range_large is not None:
|
269 |
+
jpeg_p = np.random.uniform(self.jpeg_range_large[0], self.jpeg_range_large[1])
|
270 |
+
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
|
271 |
+
_, encimg = cv2.imencode('.jpg', img_in_large * 255., encode_param)
|
272 |
+
img_in_large = np.float32(cv2.imdecode(encimg, 1)) / 255.
|
273 |
+
|
274 |
+
# resize to in_size
|
275 |
+
img_in_large = cv2.resize(img_in_large, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
|
276 |
+
|
277 |
+
|
278 |
+
# random color jitter (only for lq)
|
279 |
+
if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
|
280 |
+
img_in = self.color_jitter(img_in, self.color_jitter_shift)
|
281 |
+
img_in_large = self.color_jitter(img_in_large, self.color_jitter_shift)
|
282 |
+
# random to gray (only for lq)
|
283 |
+
if self.gray_prob and np.random.uniform() < self.gray_prob:
|
284 |
+
img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
|
285 |
+
img_in = np.tile(img_in[:, :, None], [1, 1, 3])
|
286 |
+
img_in_large = cv2.cvtColor(img_in_large, cv2.COLOR_BGR2GRAY)
|
287 |
+
img_in_large = np.tile(img_in_large[:, :, None], [1, 1, 3])
|
288 |
+
|
289 |
+
# BGR to RGB, HWC to CHW, numpy to tensor
|
290 |
+
img_in, img_in_large, img_gt = img2tensor([img_in, img_in_large, img_gt], bgr2rgb=True, float32=True)
|
291 |
+
|
292 |
+
# random color jitter (pytorch version) (only for lq)
|
293 |
+
if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
|
294 |
+
brightness = self.opt.get('brightness', (0.5, 1.5))
|
295 |
+
contrast = self.opt.get('contrast', (0.5, 1.5))
|
296 |
+
saturation = self.opt.get('saturation', (0, 1.5))
|
297 |
+
hue = self.opt.get('hue', (-0.1, 0.1))
|
298 |
+
img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)
|
299 |
+
img_in_large = self.color_jitter_pt(img_in_large, brightness, contrast, saturation, hue)
|
300 |
+
|
301 |
+
# round and clip
|
302 |
+
img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255.
|
303 |
+
img_in_large = np.clip((img_in_large * 255.0).round(), 0, 255) / 255.
|
304 |
+
|
305 |
+
# Set vgg range_norm=True if use the normalization here
|
306 |
+
# normalize
|
307 |
+
normalize(img_in, self.mean, self.std, inplace=True)
|
308 |
+
normalize(img_in_large, self.mean, self.std, inplace=True)
|
309 |
+
normalize(img_gt, self.mean, self.std, inplace=True)
|
310 |
+
|
311 |
+
return_dict = {'in': img_in, 'in_large_de': img_in_large, 'gt': img_gt, 'gt_path': gt_path}
|
312 |
+
|
313 |
+
if self.crop_components:
|
314 |
+
return_dict['locations_in'] = locations_in
|
315 |
+
return_dict['locations_gt'] = locations_gt
|
316 |
+
|
317 |
+
if self.load_latent_gt:
|
318 |
+
return_dict['latent_gt'] = latent_gt
|
319 |
+
|
320 |
+
return return_dict
|
321 |
+
|
322 |
+
|
323 |
+
def __len__(self):
|
324 |
+
return len(self.paths)
|
blissful_tuner/codeformer/basicsr/data/gaussian_kernels.py
ADDED
@@ -0,0 +1,690 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import random
|
4 |
+
from scipy.ndimage.interpolation import shift
|
5 |
+
from scipy.stats import multivariate_normal
|
6 |
+
|
7 |
+
|
8 |
+
def sigma_matrix2(sig_x, sig_y, theta):
|
9 |
+
"""Calculate the rotated sigma matrix (two dimensional matrix).
|
10 |
+
Args:
|
11 |
+
sig_x (float):
|
12 |
+
sig_y (float):
|
13 |
+
theta (float): Radian measurement.
|
14 |
+
Returns:
|
15 |
+
ndarray: Rotated sigma matrix.
|
16 |
+
"""
|
17 |
+
D = np.array([[sig_x**2, 0], [0, sig_y**2]])
|
18 |
+
U = np.array([[np.cos(theta), -np.sin(theta)],
|
19 |
+
[np.sin(theta), np.cos(theta)]])
|
20 |
+
return np.dot(U, np.dot(D, U.T))
|
21 |
+
|
22 |
+
|
23 |
+
def mesh_grid(kernel_size):
|
24 |
+
"""Generate the mesh grid, centering at zero.
|
25 |
+
Args:
|
26 |
+
kernel_size (int):
|
27 |
+
Returns:
|
28 |
+
xy (ndarray): with the shape (kernel_size, kernel_size, 2)
|
29 |
+
xx (ndarray): with the shape (kernel_size, kernel_size)
|
30 |
+
yy (ndarray): with the shape (kernel_size, kernel_size)
|
31 |
+
"""
|
32 |
+
ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
|
33 |
+
xx, yy = np.meshgrid(ax, ax)
|
34 |
+
xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)),
|
35 |
+
yy.reshape(kernel_size * kernel_size,
|
36 |
+
1))).reshape(kernel_size, kernel_size, 2)
|
37 |
+
return xy, xx, yy
|
38 |
+
|
39 |
+
|
40 |
+
def pdf2(sigma_matrix, grid):
|
41 |
+
"""Calculate PDF of the bivariate Gaussian distribution.
|
42 |
+
Args:
|
43 |
+
sigma_matrix (ndarray): with the shape (2, 2)
|
44 |
+
grid (ndarray): generated by :func:`mesh_grid`,
|
45 |
+
with the shape (K, K, 2), K is the kernel size.
|
46 |
+
Returns:
|
47 |
+
kernel (ndarrray): un-normalized kernel.
|
48 |
+
"""
|
49 |
+
inverse_sigma = np.linalg.inv(sigma_matrix)
|
50 |
+
kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
|
51 |
+
return kernel
|
52 |
+
|
53 |
+
|
54 |
+
def cdf2(D, grid):
|
55 |
+
"""Calculate the CDF of the standard bivariate Gaussian distribution.
|
56 |
+
Used in skewed Gaussian distribution.
|
57 |
+
Args:
|
58 |
+
D (ndarrasy): skew matrix.
|
59 |
+
grid (ndarray): generated by :func:`mesh_grid`,
|
60 |
+
with the shape (K, K, 2), K is the kernel size.
|
61 |
+
Returns:
|
62 |
+
cdf (ndarray): skewed cdf.
|
63 |
+
"""
|
64 |
+
rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
|
65 |
+
grid = np.dot(grid, D)
|
66 |
+
cdf = rv.cdf(grid)
|
67 |
+
return cdf
|
68 |
+
|
69 |
+
|
70 |
+
def bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid=None):
|
71 |
+
"""Generate a bivariate skew Gaussian kernel.
|
72 |
+
Described in `A multivariate skew normal distribution`_ by Shi et. al (2004).
|
73 |
+
Args:
|
74 |
+
kernel_size (int):
|
75 |
+
sig_x (float):
|
76 |
+
sig_y (float):
|
77 |
+
theta (float): Radian measurement.
|
78 |
+
D (ndarrasy): skew matrix.
|
79 |
+
grid (ndarray, optional): generated by :func:`mesh_grid`,
|
80 |
+
with the shape (K, K, 2), K is the kernel size. Default: None
|
81 |
+
Returns:
|
82 |
+
kernel (ndarray): normalized kernel.
|
83 |
+
.. _A multivariate skew normal distribution:
|
84 |
+
https://www.sciencedirect.com/science/article/pii/S0047259X03001313
|
85 |
+
"""
|
86 |
+
if grid is None:
|
87 |
+
grid, _, _ = mesh_grid(kernel_size)
|
88 |
+
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
|
89 |
+
pdf = pdf2(sigma_matrix, grid)
|
90 |
+
cdf = cdf2(D, grid)
|
91 |
+
kernel = pdf * cdf
|
92 |
+
kernel = kernel / np.sum(kernel)
|
93 |
+
return kernel
|
94 |
+
|
95 |
+
|
96 |
+
def mass_center_shift(kernel_size, kernel):
|
97 |
+
"""Calculate the shift of the mass center of a kenrel.
|
98 |
+
Args:
|
99 |
+
kernel_size (int):
|
100 |
+
kernel (ndarray): normalized kernel.
|
101 |
+
Returns:
|
102 |
+
delta_h (float):
|
103 |
+
delta_w (float):
|
104 |
+
"""
|
105 |
+
ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
|
106 |
+
col_sum, row_sum = np.sum(kernel, axis=0), np.sum(kernel, axis=1)
|
107 |
+
delta_h = np.dot(row_sum, ax)
|
108 |
+
delta_w = np.dot(col_sum, ax)
|
109 |
+
return delta_h, delta_w
|
110 |
+
|
111 |
+
|
112 |
+
def bivariate_skew_Gaussian_center(kernel_size,
|
113 |
+
sig_x,
|
114 |
+
sig_y,
|
115 |
+
theta,
|
116 |
+
D,
|
117 |
+
grid=None):
|
118 |
+
"""Generate a bivariate skew Gaussian kernel at center. Shift with nearest padding.
|
119 |
+
Args:
|
120 |
+
kernel_size (int):
|
121 |
+
sig_x (float):
|
122 |
+
sig_y (float):
|
123 |
+
theta (float): Radian measurement.
|
124 |
+
D (ndarrasy): skew matrix.
|
125 |
+
grid (ndarray, optional): generated by :func:`mesh_grid`,
|
126 |
+
with the shape (K, K, 2), K is the kernel size. Default: None
|
127 |
+
Returns:
|
128 |
+
kernel (ndarray): centered and normalized kernel.
|
129 |
+
"""
|
130 |
+
if grid is None:
|
131 |
+
grid, _, _ = mesh_grid(kernel_size)
|
132 |
+
kernel = bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid)
|
133 |
+
delta_h, delta_w = mass_center_shift(kernel_size, kernel)
|
134 |
+
kernel = shift(kernel, [-delta_h, -delta_w], mode='nearest')
|
135 |
+
kernel = kernel / np.sum(kernel)
|
136 |
+
return kernel
|
137 |
+
|
138 |
+
|
139 |
+
def bivariate_anisotropic_Gaussian(kernel_size,
|
140 |
+
sig_x,
|
141 |
+
sig_y,
|
142 |
+
theta,
|
143 |
+
grid=None):
|
144 |
+
"""Generate a bivariate anisotropic Gaussian kernel.
|
145 |
+
Args:
|
146 |
+
kernel_size (int):
|
147 |
+
sig_x (float):
|
148 |
+
sig_y (float):
|
149 |
+
theta (float): Radian measurement.
|
150 |
+
grid (ndarray, optional): generated by :func:`mesh_grid`,
|
151 |
+
with the shape (K, K, 2), K is the kernel size. Default: None
|
152 |
+
Returns:
|
153 |
+
kernel (ndarray): normalized kernel.
|
154 |
+
"""
|
155 |
+
if grid is None:
|
156 |
+
grid, _, _ = mesh_grid(kernel_size)
|
157 |
+
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
|
158 |
+
kernel = pdf2(sigma_matrix, grid)
|
159 |
+
kernel = kernel / np.sum(kernel)
|
160 |
+
return kernel
|
161 |
+
|
162 |
+
|
163 |
+
def bivariate_isotropic_Gaussian(kernel_size, sig, grid=None):
|
164 |
+
"""Generate a bivariate isotropic Gaussian kernel.
|
165 |
+
Args:
|
166 |
+
kernel_size (int):
|
167 |
+
sig (float):
|
168 |
+
grid (ndarray, optional): generated by :func:`mesh_grid`,
|
169 |
+
with the shape (K, K, 2), K is the kernel size. Default: None
|
170 |
+
Returns:
|
171 |
+
kernel (ndarray): normalized kernel.
|
172 |
+
"""
|
173 |
+
if grid is None:
|
174 |
+
grid, _, _ = mesh_grid(kernel_size)
|
175 |
+
sigma_matrix = np.array([[sig**2, 0], [0, sig**2]])
|
176 |
+
kernel = pdf2(sigma_matrix, grid)
|
177 |
+
kernel = kernel / np.sum(kernel)
|
178 |
+
return kernel
|
179 |
+
|
180 |
+
|
181 |
+
def bivariate_generalized_Gaussian(kernel_size,
|
182 |
+
sig_x,
|
183 |
+
sig_y,
|
184 |
+
theta,
|
185 |
+
beta,
|
186 |
+
grid=None):
|
187 |
+
"""Generate a bivariate generalized Gaussian kernel.
|
188 |
+
Described in `Parameter Estimation For Multivariate Generalized Gaussian Distributions`_
|
189 |
+
by Pascal et. al (2013).
|
190 |
+
Args:
|
191 |
+
kernel_size (int):
|
192 |
+
sig_x (float):
|
193 |
+
sig_y (float):
|
194 |
+
theta (float): Radian measurement.
|
195 |
+
beta (float): shape parameter, beta = 1 is the normal distribution.
|
196 |
+
grid (ndarray, optional): generated by :func:`mesh_grid`,
|
197 |
+
with the shape (K, K, 2), K is the kernel size. Default: None
|
198 |
+
Returns:
|
199 |
+
kernel (ndarray): normalized kernel.
|
200 |
+
.. _Parameter Estimation For Multivariate Generalized Gaussian Distributions:
|
201 |
+
https://arxiv.org/abs/1302.6498
|
202 |
+
"""
|
203 |
+
if grid is None:
|
204 |
+
grid, _, _ = mesh_grid(kernel_size)
|
205 |
+
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
|
206 |
+
inverse_sigma = np.linalg.inv(sigma_matrix)
|
207 |
+
kernel = np.exp(
|
208 |
+
-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
|
209 |
+
kernel = kernel / np.sum(kernel)
|
210 |
+
return kernel
|
211 |
+
|
212 |
+
|
213 |
+
def bivariate_plateau_type1(kernel_size, sig_x, sig_y, theta, beta, grid=None):
|
214 |
+
"""Generate a plateau-like anisotropic kernel.
|
215 |
+
1 / (1+x^(beta))
|
216 |
+
Args:
|
217 |
+
kernel_size (int):
|
218 |
+
sig_x (float):
|
219 |
+
sig_y (float):
|
220 |
+
theta (float): Radian measurement.
|
221 |
+
beta (float): shape parameter, beta = 1 is the normal distribution.
|
222 |
+
grid (ndarray, optional): generated by :func:`mesh_grid`,
|
223 |
+
with the shape (K, K, 2), K is the kernel size. Default: None
|
224 |
+
Returns:
|
225 |
+
kernel (ndarray): normalized kernel.
|
226 |
+
"""
|
227 |
+
if grid is None:
|
228 |
+
grid, _, _ = mesh_grid(kernel_size)
|
229 |
+
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
|
230 |
+
inverse_sigma = np.linalg.inv(sigma_matrix)
|
231 |
+
kernel = np.reciprocal(
|
232 |
+
np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
|
233 |
+
kernel = kernel / np.sum(kernel)
|
234 |
+
return kernel
|
235 |
+
|
236 |
+
|
237 |
+
def bivariate_plateau_type1_iso(kernel_size, sig, beta, grid=None):
|
238 |
+
"""Generate a plateau-like isotropic kernel.
|
239 |
+
1 / (1+x^(beta))
|
240 |
+
Args:
|
241 |
+
kernel_size (int):
|
242 |
+
sig (float):
|
243 |
+
beta (float): shape parameter, beta = 1 is the normal distribution.
|
244 |
+
grid (ndarray, optional): generated by :func:`mesh_grid`,
|
245 |
+
with the shape (K, K, 2), K is the kernel size. Default: None
|
246 |
+
Returns:
|
247 |
+
kernel (ndarray): normalized kernel.
|
248 |
+
"""
|
249 |
+
if grid is None:
|
250 |
+
grid, _, _ = mesh_grid(kernel_size)
|
251 |
+
sigma_matrix = np.array([[sig**2, 0], [0, sig**2]])
|
252 |
+
inverse_sigma = np.linalg.inv(sigma_matrix)
|
253 |
+
kernel = np.reciprocal(
|
254 |
+
np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
|
255 |
+
kernel = kernel / np.sum(kernel)
|
256 |
+
return kernel
|
257 |
+
|
258 |
+
|
259 |
+
def random_bivariate_skew_Gaussian_center(kernel_size,
|
260 |
+
sigma_x_range,
|
261 |
+
sigma_y_range,
|
262 |
+
rotation_range,
|
263 |
+
noise_range=None,
|
264 |
+
strict=False):
|
265 |
+
"""Randomly generate bivariate skew Gaussian kernels at center.
|
266 |
+
Args:
|
267 |
+
kernel_size (int):
|
268 |
+
sigma_x_range (tuple): [0.6, 5]
|
269 |
+
sigma_y_range (tuple): [0.6, 5]
|
270 |
+
rotation range (tuple): [-math.pi, math.pi]
|
271 |
+
noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
|
272 |
+
Returns:
|
273 |
+
kernel (ndarray):
|
274 |
+
"""
|
275 |
+
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
276 |
+
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
|
277 |
+
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
|
278 |
+
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
|
279 |
+
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
|
280 |
+
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
|
281 |
+
if strict:
|
282 |
+
sigma_max = np.max([sigma_x, sigma_y])
|
283 |
+
sigma_min = np.min([sigma_x, sigma_y])
|
284 |
+
sigma_x, sigma_y = sigma_max, sigma_min
|
285 |
+
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
|
286 |
+
|
287 |
+
sigma_max = np.max([sigma_x, sigma_y])
|
288 |
+
thres = 3 / sigma_max
|
289 |
+
D = [[np.random.uniform(-thres, thres),
|
290 |
+
np.random.uniform(-thres, thres)],
|
291 |
+
[np.random.uniform(-thres, thres),
|
292 |
+
np.random.uniform(-thres, thres)]]
|
293 |
+
|
294 |
+
kernel = bivariate_skew_Gaussian_center(kernel_size, sigma_x, sigma_y,
|
295 |
+
rotation, D)
|
296 |
+
|
297 |
+
# add multiplicative noise
|
298 |
+
if noise_range is not None:
|
299 |
+
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
|
300 |
+
noise = np.random.uniform(
|
301 |
+
noise_range[0], noise_range[1], size=kernel.shape)
|
302 |
+
kernel = kernel * noise
|
303 |
+
kernel = kernel / np.sum(kernel)
|
304 |
+
if strict:
|
305 |
+
return kernel, sigma_x, sigma_y, rotation, D
|
306 |
+
else:
|
307 |
+
return kernel
|
308 |
+
|
309 |
+
|
310 |
+
def random_bivariate_anisotropic_Gaussian(kernel_size,
|
311 |
+
sigma_x_range,
|
312 |
+
sigma_y_range,
|
313 |
+
rotation_range,
|
314 |
+
noise_range=None,
|
315 |
+
strict=False):
|
316 |
+
"""Randomly generate bivariate anisotropic Gaussian kernels.
|
317 |
+
Args:
|
318 |
+
kernel_size (int):
|
319 |
+
sigma_x_range (tuple): [0.6, 5]
|
320 |
+
sigma_y_range (tuple): [0.6, 5]
|
321 |
+
rotation range (tuple): [-math.pi, math.pi]
|
322 |
+
noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
|
323 |
+
Returns:
|
324 |
+
kernel (ndarray):
|
325 |
+
"""
|
326 |
+
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
327 |
+
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
|
328 |
+
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
|
329 |
+
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
|
330 |
+
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
|
331 |
+
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
|
332 |
+
if strict:
|
333 |
+
sigma_max = np.max([sigma_x, sigma_y])
|
334 |
+
sigma_min = np.min([sigma_x, sigma_y])
|
335 |
+
sigma_x, sigma_y = sigma_max, sigma_min
|
336 |
+
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
|
337 |
+
|
338 |
+
kernel = bivariate_anisotropic_Gaussian(kernel_size, sigma_x, sigma_y,
|
339 |
+
rotation)
|
340 |
+
|
341 |
+
# add multiplicative noise
|
342 |
+
if noise_range is not None:
|
343 |
+
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
|
344 |
+
noise = np.random.uniform(
|
345 |
+
noise_range[0], noise_range[1], size=kernel.shape)
|
346 |
+
kernel = kernel * noise
|
347 |
+
kernel = kernel / np.sum(kernel)
|
348 |
+
if strict:
|
349 |
+
return kernel, sigma_x, sigma_y, rotation
|
350 |
+
else:
|
351 |
+
return kernel
|
352 |
+
|
353 |
+
|
354 |
+
def random_bivariate_isotropic_Gaussian(kernel_size,
|
355 |
+
sigma_range,
|
356 |
+
noise_range=None,
|
357 |
+
strict=False):
|
358 |
+
"""Randomly generate bivariate isotropic Gaussian kernels.
|
359 |
+
Args:
|
360 |
+
kernel_size (int):
|
361 |
+
sigma_range (tuple): [0.6, 5]
|
362 |
+
noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
|
363 |
+
Returns:
|
364 |
+
kernel (ndarray):
|
365 |
+
"""
|
366 |
+
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
367 |
+
assert sigma_range[0] < sigma_range[1], 'Wrong sigma_x_range.'
|
368 |
+
sigma = np.random.uniform(sigma_range[0], sigma_range[1])
|
369 |
+
|
370 |
+
kernel = bivariate_isotropic_Gaussian(kernel_size, sigma)
|
371 |
+
|
372 |
+
# add multiplicative noise
|
373 |
+
if noise_range is not None:
|
374 |
+
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
|
375 |
+
noise = np.random.uniform(
|
376 |
+
noise_range[0], noise_range[1], size=kernel.shape)
|
377 |
+
kernel = kernel * noise
|
378 |
+
kernel = kernel / np.sum(kernel)
|
379 |
+
if strict:
|
380 |
+
return kernel, sigma
|
381 |
+
else:
|
382 |
+
return kernel
|
383 |
+
|
384 |
+
|
385 |
+
def random_bivariate_generalized_Gaussian(kernel_size,
|
386 |
+
sigma_x_range,
|
387 |
+
sigma_y_range,
|
388 |
+
rotation_range,
|
389 |
+
beta_range,
|
390 |
+
noise_range=None,
|
391 |
+
strict=False):
|
392 |
+
"""Randomly generate bivariate generalized Gaussian kernels.
|
393 |
+
Args:
|
394 |
+
kernel_size (int):
|
395 |
+
sigma_x_range (tuple): [0.6, 5]
|
396 |
+
sigma_y_range (tuple): [0.6, 5]
|
397 |
+
rotation range (tuple): [-math.pi, math.pi]
|
398 |
+
beta_range (tuple): [0.5, 8]
|
399 |
+
noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
|
400 |
+
Returns:
|
401 |
+
kernel (ndarray):
|
402 |
+
"""
|
403 |
+
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
404 |
+
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
|
405 |
+
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
|
406 |
+
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
|
407 |
+
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
|
408 |
+
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
|
409 |
+
if strict:
|
410 |
+
sigma_max = np.max([sigma_x, sigma_y])
|
411 |
+
sigma_min = np.min([sigma_x, sigma_y])
|
412 |
+
sigma_x, sigma_y = sigma_max, sigma_min
|
413 |
+
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
|
414 |
+
if np.random.uniform() < 0.5:
|
415 |
+
beta = np.random.uniform(beta_range[0], 1)
|
416 |
+
else:
|
417 |
+
beta = np.random.uniform(1, beta_range[1])
|
418 |
+
|
419 |
+
kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y,
|
420 |
+
rotation, beta)
|
421 |
+
|
422 |
+
# add multiplicative noise
|
423 |
+
if noise_range is not None:
|
424 |
+
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
|
425 |
+
noise = np.random.uniform(
|
426 |
+
noise_range[0], noise_range[1], size=kernel.shape)
|
427 |
+
kernel = kernel * noise
|
428 |
+
kernel = kernel / np.sum(kernel)
|
429 |
+
if strict:
|
430 |
+
return kernel, sigma_x, sigma_y, rotation, beta
|
431 |
+
else:
|
432 |
+
return kernel
|
433 |
+
|
434 |
+
|
435 |
+
def random_bivariate_plateau_type1(kernel_size,
|
436 |
+
sigma_x_range,
|
437 |
+
sigma_y_range,
|
438 |
+
rotation_range,
|
439 |
+
beta_range,
|
440 |
+
noise_range=None,
|
441 |
+
strict=False):
|
442 |
+
"""Randomly generate bivariate plateau type1 kernels.
|
443 |
+
Args:
|
444 |
+
kernel_size (int):
|
445 |
+
sigma_x_range (tuple): [0.6, 5]
|
446 |
+
sigma_y_range (tuple): [0.6, 5]
|
447 |
+
rotation range (tuple): [-math.pi/2, math.pi/2]
|
448 |
+
beta_range (tuple): [1, 4]
|
449 |
+
noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
|
450 |
+
Returns:
|
451 |
+
kernel (ndarray):
|
452 |
+
"""
|
453 |
+
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
454 |
+
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
|
455 |
+
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
|
456 |
+
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
|
457 |
+
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
|
458 |
+
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
|
459 |
+
if strict:
|
460 |
+
sigma_max = np.max([sigma_x, sigma_y])
|
461 |
+
sigma_min = np.min([sigma_x, sigma_y])
|
462 |
+
sigma_x, sigma_y = sigma_max, sigma_min
|
463 |
+
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
|
464 |
+
if np.random.uniform() < 0.5:
|
465 |
+
beta = np.random.uniform(beta_range[0], 1)
|
466 |
+
else:
|
467 |
+
beta = np.random.uniform(1, beta_range[1])
|
468 |
+
|
469 |
+
kernel = bivariate_plateau_type1(kernel_size, sigma_x, sigma_y, rotation,
|
470 |
+
beta)
|
471 |
+
|
472 |
+
# add multiplicative noise
|
473 |
+
if noise_range is not None:
|
474 |
+
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
|
475 |
+
noise = np.random.uniform(
|
476 |
+
noise_range[0], noise_range[1], size=kernel.shape)
|
477 |
+
kernel = kernel * noise
|
478 |
+
kernel = kernel / np.sum(kernel)
|
479 |
+
if strict:
|
480 |
+
return kernel, sigma_x, sigma_y, rotation, beta
|
481 |
+
else:
|
482 |
+
return kernel
|
483 |
+
|
484 |
+
|
485 |
+
def random_bivariate_plateau_type1_iso(kernel_size,
|
486 |
+
sigma_range,
|
487 |
+
beta_range,
|
488 |
+
noise_range=None,
|
489 |
+
strict=False):
|
490 |
+
"""Randomly generate bivariate plateau type1 kernels (iso).
|
491 |
+
Args:
|
492 |
+
kernel_size (int):
|
493 |
+
sigma_range (tuple): [0.6, 5]
|
494 |
+
beta_range (tuple): [1, 4]
|
495 |
+
noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
|
496 |
+
Returns:
|
497 |
+
kernel (ndarray):
|
498 |
+
"""
|
499 |
+
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
500 |
+
assert sigma_range[0] < sigma_range[1], 'Wrong sigma_x_range.'
|
501 |
+
sigma = np.random.uniform(sigma_range[0], sigma_range[1])
|
502 |
+
beta = np.random.uniform(beta_range[0], beta_range[1])
|
503 |
+
|
504 |
+
kernel = bivariate_plateau_type1_iso(kernel_size, sigma, beta)
|
505 |
+
|
506 |
+
# add multiplicative noise
|
507 |
+
if noise_range is not None:
|
508 |
+
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
|
509 |
+
noise = np.random.uniform(
|
510 |
+
noise_range[0], noise_range[1], size=kernel.shape)
|
511 |
+
kernel = kernel * noise
|
512 |
+
kernel = kernel / np.sum(kernel)
|
513 |
+
if strict:
|
514 |
+
return kernel, sigma, beta
|
515 |
+
else:
|
516 |
+
return kernel
|
517 |
+
|
518 |
+
|
519 |
+
def random_mixed_kernels(kernel_list,
|
520 |
+
kernel_prob,
|
521 |
+
kernel_size=21,
|
522 |
+
sigma_x_range=[0.6, 5],
|
523 |
+
sigma_y_range=[0.6, 5],
|
524 |
+
rotation_range=[-math.pi, math.pi],
|
525 |
+
beta_range=[0.5, 8],
|
526 |
+
noise_range=None):
|
527 |
+
"""Randomly generate mixed kernels.
|
528 |
+
Args:
|
529 |
+
kernel_list (tuple): a list name of kenrel types,
|
530 |
+
support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso', 'plateau_aniso']
|
531 |
+
kernel_prob (tuple): corresponding kernel probability for each kernel type
|
532 |
+
kernel_size (int):
|
533 |
+
sigma_x_range (tuple): [0.6, 5]
|
534 |
+
sigma_y_range (tuple): [0.6, 5]
|
535 |
+
rotation range (tuple): [-math.pi, math.pi]
|
536 |
+
beta_range (tuple): [0.5, 8]
|
537 |
+
noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
|
538 |
+
Returns:
|
539 |
+
kernel (ndarray):
|
540 |
+
"""
|
541 |
+
kernel_type = random.choices(kernel_list, kernel_prob)[0]
|
542 |
+
if kernel_type == 'iso':
|
543 |
+
kernel = random_bivariate_isotropic_Gaussian(
|
544 |
+
kernel_size, sigma_x_range, noise_range=noise_range)
|
545 |
+
elif kernel_type == 'aniso':
|
546 |
+
kernel = random_bivariate_anisotropic_Gaussian(
|
547 |
+
kernel_size,
|
548 |
+
sigma_x_range,
|
549 |
+
sigma_y_range,
|
550 |
+
rotation_range,
|
551 |
+
noise_range=noise_range)
|
552 |
+
elif kernel_type == 'skew':
|
553 |
+
kernel = random_bivariate_skew_Gaussian_center(
|
554 |
+
kernel_size,
|
555 |
+
sigma_x_range,
|
556 |
+
sigma_y_range,
|
557 |
+
rotation_range,
|
558 |
+
noise_range=noise_range)
|
559 |
+
elif kernel_type == 'generalized':
|
560 |
+
kernel = random_bivariate_generalized_Gaussian(
|
561 |
+
kernel_size,
|
562 |
+
sigma_x_range,
|
563 |
+
sigma_y_range,
|
564 |
+
rotation_range,
|
565 |
+
beta_range,
|
566 |
+
noise_range=noise_range)
|
567 |
+
elif kernel_type == 'plateau_iso':
|
568 |
+
kernel = random_bivariate_plateau_type1_iso(
|
569 |
+
kernel_size, sigma_x_range, beta_range, noise_range=noise_range)
|
570 |
+
elif kernel_type == 'plateau_aniso':
|
571 |
+
kernel = random_bivariate_plateau_type1(
|
572 |
+
kernel_size,
|
573 |
+
sigma_x_range,
|
574 |
+
sigma_y_range,
|
575 |
+
rotation_range,
|
576 |
+
beta_range,
|
577 |
+
noise_range=noise_range)
|
578 |
+
# add multiplicative noise
|
579 |
+
if noise_range is not None:
|
580 |
+
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
|
581 |
+
noise = np.random.uniform(
|
582 |
+
noise_range[0], noise_range[1], size=kernel.shape)
|
583 |
+
kernel = kernel * noise
|
584 |
+
kernel = kernel / np.sum(kernel)
|
585 |
+
return kernel
|
586 |
+
|
587 |
+
|
588 |
+
def show_one_kernel():
|
589 |
+
import matplotlib.pyplot as plt
|
590 |
+
kernel_size = 21
|
591 |
+
|
592 |
+
# bivariate skew Gaussian
|
593 |
+
D = [[0, 0], [0, 0]]
|
594 |
+
D = [[3 / 4, 0], [0, 0.5]]
|
595 |
+
kernel = bivariate_skew_Gaussian_center(kernel_size, 2, 4, -math.pi / 4, D)
|
596 |
+
# bivariate anisotropic Gaussian
|
597 |
+
kernel = bivariate_anisotropic_Gaussian(kernel_size, 2, 4, -math.pi / 4)
|
598 |
+
# bivariate anisotropic Gaussian
|
599 |
+
kernel = bivariate_isotropic_Gaussian(kernel_size, 1)
|
600 |
+
# bivariate generalized Gaussian
|
601 |
+
kernel = bivariate_generalized_Gaussian(
|
602 |
+
kernel_size, 2, 4, -math.pi / 4, beta=4)
|
603 |
+
|
604 |
+
delta_h, delta_w = mass_center_shift(kernel_size, kernel)
|
605 |
+
print(delta_h, delta_w)
|
606 |
+
|
607 |
+
fig, axs = plt.subplots(nrows=2, ncols=2)
|
608 |
+
# axs.set_axis_off()
|
609 |
+
ax = axs[0][0]
|
610 |
+
im = ax.matshow(kernel, cmap='jet', origin='upper')
|
611 |
+
fig.colorbar(im, ax=ax)
|
612 |
+
|
613 |
+
# image
|
614 |
+
ax = axs[0][1]
|
615 |
+
kernel_vis = kernel - np.min(kernel)
|
616 |
+
kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
|
617 |
+
ax.imshow(kernel_vis, interpolation='nearest')
|
618 |
+
|
619 |
+
_, xx, yy = mesh_grid(kernel_size)
|
620 |
+
# contour
|
621 |
+
ax = axs[1][0]
|
622 |
+
CS = ax.contour(xx, yy, kernel, origin='upper')
|
623 |
+
ax.clabel(CS, inline=1, fontsize=3)
|
624 |
+
|
625 |
+
# contourf
|
626 |
+
ax = axs[1][1]
|
627 |
+
kernel = kernel / np.max(kernel)
|
628 |
+
p = ax.contourf(
|
629 |
+
xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
|
630 |
+
fig.colorbar(p)
|
631 |
+
|
632 |
+
plt.show()
|
633 |
+
|
634 |
+
|
635 |
+
def show_plateau_kernel():
|
636 |
+
import matplotlib.pyplot as plt
|
637 |
+
kernel_size = 21
|
638 |
+
|
639 |
+
kernel = plateau_type1(kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
|
640 |
+
kernel_norm = bivariate_isotropic_Gaussian(kernel_size, 5)
|
641 |
+
kernel_gau = bivariate_generalized_Gaussian(
|
642 |
+
kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
|
643 |
+
delta_h, delta_w = mass_center_shift(kernel_size, kernel)
|
644 |
+
print(delta_h, delta_w)
|
645 |
+
|
646 |
+
# kernel_slice = kernel[10, :]
|
647 |
+
# kernel_gau_slice = kernel_gau[10, :]
|
648 |
+
# kernel_norm_slice = kernel_norm[10, :]
|
649 |
+
# fig, ax = plt.subplots()
|
650 |
+
# t = list(range(1, 22))
|
651 |
+
|
652 |
+
# ax.plot(t, kernel_gau_slice)
|
653 |
+
# ax.plot(t, kernel_slice)
|
654 |
+
# ax.plot(t, kernel_norm_slice)
|
655 |
+
|
656 |
+
# t = np.arange(0, 10, 0.1)
|
657 |
+
# y = np.exp(-0.5 * t)
|
658 |
+
# y2 = np.reciprocal(1 + t)
|
659 |
+
# print(t.shape)
|
660 |
+
# print(y.shape)
|
661 |
+
# ax.plot(t, y)
|
662 |
+
# ax.plot(t, y2)
|
663 |
+
# plt.show()
|
664 |
+
|
665 |
+
fig, axs = plt.subplots(nrows=2, ncols=2)
|
666 |
+
# axs.set_axis_off()
|
667 |
+
ax = axs[0][0]
|
668 |
+
im = ax.matshow(kernel, cmap='jet', origin='upper')
|
669 |
+
fig.colorbar(im, ax=ax)
|
670 |
+
|
671 |
+
# image
|
672 |
+
ax = axs[0][1]
|
673 |
+
kernel_vis = kernel - np.min(kernel)
|
674 |
+
kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
|
675 |
+
ax.imshow(kernel_vis, interpolation='nearest')
|
676 |
+
|
677 |
+
_, xx, yy = mesh_grid(kernel_size)
|
678 |
+
# contour
|
679 |
+
ax = axs[1][0]
|
680 |
+
CS = ax.contour(xx, yy, kernel, origin='upper')
|
681 |
+
ax.clabel(CS, inline=1, fontsize=3)
|
682 |
+
|
683 |
+
# contourf
|
684 |
+
ax = axs[1][1]
|
685 |
+
kernel = kernel / np.max(kernel)
|
686 |
+
p = ax.contourf(
|
687 |
+
xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
|
688 |
+
fig.colorbar(p)
|
689 |
+
|
690 |
+
plt.show()
|
blissful_tuner/codeformer/basicsr/data/paired_image_dataset.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils import data as data
|
2 |
+
from torchvision.transforms.functional import normalize
|
3 |
+
|
4 |
+
from codeformer.basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
|
5 |
+
from codeformer.basicsr.data.transforms import augment, paired_random_crop
|
6 |
+
from codeformer.basicsr.utils import FileClient, imfrombytes, img2tensor
|
7 |
+
from codeformer.basicsr.utils.registry import DATASET_REGISTRY
|
8 |
+
|
9 |
+
|
10 |
+
@DATASET_REGISTRY.register()
|
11 |
+
class PairedImageDataset(data.Dataset):
|
12 |
+
"""Paired image dataset for image restoration.
|
13 |
+
|
14 |
+
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
|
15 |
+
GT image pairs.
|
16 |
+
|
17 |
+
There are three modes:
|
18 |
+
1. 'lmdb': Use lmdb files.
|
19 |
+
If opt['io_backend'] == lmdb.
|
20 |
+
2. 'meta_info_file': Use meta information file to generate paths.
|
21 |
+
If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
|
22 |
+
3. 'folder': Scan folders to generate paths.
|
23 |
+
The rest.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
opt (dict): Config for train datasets. It contains the following keys:
|
27 |
+
dataroot_gt (str): Data root path for gt.
|
28 |
+
dataroot_lq (str): Data root path for lq.
|
29 |
+
meta_info_file (str): Path for meta information file.
|
30 |
+
io_backend (dict): IO backend type and other kwarg.
|
31 |
+
filename_tmpl (str): Template for each filename. Note that the
|
32 |
+
template excludes the file extension. Default: '{}'.
|
33 |
+
gt_size (int): Cropped patched size for gt patches.
|
34 |
+
use_flip (bool): Use horizontal flips.
|
35 |
+
use_rot (bool): Use rotation (use vertical flip and transposing h
|
36 |
+
and w for implementation).
|
37 |
+
|
38 |
+
scale (bool): Scale, which will be added automatically.
|
39 |
+
phase (str): 'train' or 'val'.
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(self, opt):
|
43 |
+
super(PairedImageDataset, self).__init__()
|
44 |
+
self.opt = opt
|
45 |
+
# file client (io backend)
|
46 |
+
self.file_client = None
|
47 |
+
self.io_backend_opt = opt['io_backend']
|
48 |
+
self.mean = opt['mean'] if 'mean' in opt else None
|
49 |
+
self.std = opt['std'] if 'std' in opt else None
|
50 |
+
|
51 |
+
self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
|
52 |
+
if 'filename_tmpl' in opt:
|
53 |
+
self.filename_tmpl = opt['filename_tmpl']
|
54 |
+
else:
|
55 |
+
self.filename_tmpl = '{}'
|
56 |
+
|
57 |
+
if self.io_backend_opt['type'] == 'lmdb':
|
58 |
+
self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
|
59 |
+
self.io_backend_opt['client_keys'] = ['lq', 'gt']
|
60 |
+
self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
|
61 |
+
elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
|
62 |
+
self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
|
63 |
+
self.opt['meta_info_file'], self.filename_tmpl)
|
64 |
+
else:
|
65 |
+
self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
|
66 |
+
|
67 |
+
def __getitem__(self, index):
|
68 |
+
if self.file_client is None:
|
69 |
+
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
|
70 |
+
|
71 |
+
scale = self.opt['scale']
|
72 |
+
|
73 |
+
# Load gt and lq images. Dimension order: HWC; channel order: BGR;
|
74 |
+
# image range: [0, 1], float32.
|
75 |
+
gt_path = self.paths[index]['gt_path']
|
76 |
+
img_bytes = self.file_client.get(gt_path, 'gt')
|
77 |
+
img_gt = imfrombytes(img_bytes, float32=True)
|
78 |
+
lq_path = self.paths[index]['lq_path']
|
79 |
+
img_bytes = self.file_client.get(lq_path, 'lq')
|
80 |
+
img_lq = imfrombytes(img_bytes, float32=True)
|
81 |
+
|
82 |
+
# augmentation for training
|
83 |
+
if self.opt['phase'] == 'train':
|
84 |
+
gt_size = self.opt['gt_size']
|
85 |
+
# random crop
|
86 |
+
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
|
87 |
+
# flip, rotation
|
88 |
+
img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])
|
89 |
+
|
90 |
+
# TODO: color space transform
|
91 |
+
# BGR to RGB, HWC to CHW, numpy to tensor
|
92 |
+
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
|
93 |
+
# normalize
|
94 |
+
if self.mean is not None or self.std is not None:
|
95 |
+
normalize(img_lq, self.mean, self.std, inplace=True)
|
96 |
+
normalize(img_gt, self.mean, self.std, inplace=True)
|
97 |
+
|
98 |
+
return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
|
99 |
+
|
100 |
+
def __len__(self):
|
101 |
+
return len(self.paths)
|
blissful_tuner/codeformer/basicsr/data/prefetch_dataloader.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import queue as Queue
|
2 |
+
import threading
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
|
6 |
+
|
7 |
+
class PrefetchGenerator(threading.Thread):
|
8 |
+
"""A general prefetch generator.
|
9 |
+
|
10 |
+
Ref:
|
11 |
+
https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
|
12 |
+
|
13 |
+
Args:
|
14 |
+
generator: Python generator.
|
15 |
+
num_prefetch_queue (int): Number of prefetch queue.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, generator, num_prefetch_queue):
|
19 |
+
threading.Thread.__init__(self)
|
20 |
+
self.queue = Queue.Queue(num_prefetch_queue)
|
21 |
+
self.generator = generator
|
22 |
+
self.daemon = True
|
23 |
+
self.start()
|
24 |
+
|
25 |
+
def run(self):
|
26 |
+
for item in self.generator:
|
27 |
+
self.queue.put(item)
|
28 |
+
self.queue.put(None)
|
29 |
+
|
30 |
+
def __next__(self):
|
31 |
+
next_item = self.queue.get()
|
32 |
+
if next_item is None:
|
33 |
+
raise StopIteration
|
34 |
+
return next_item
|
35 |
+
|
36 |
+
def __iter__(self):
|
37 |
+
return self
|
38 |
+
|
39 |
+
|
40 |
+
class PrefetchDataLoader(DataLoader):
|
41 |
+
"""Prefetch version of dataloader.
|
42 |
+
|
43 |
+
Ref:
|
44 |
+
https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
|
45 |
+
|
46 |
+
TODO:
|
47 |
+
Need to test on single gpu and ddp (multi-gpu). There is a known issue in
|
48 |
+
ddp.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
num_prefetch_queue (int): Number of prefetch queue.
|
52 |
+
kwargs (dict): Other arguments for dataloader.
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(self, num_prefetch_queue, **kwargs):
|
56 |
+
self.num_prefetch_queue = num_prefetch_queue
|
57 |
+
super(PrefetchDataLoader, self).__init__(**kwargs)
|
58 |
+
|
59 |
+
def __iter__(self):
|
60 |
+
return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
|
61 |
+
|
62 |
+
|
63 |
+
class CPUPrefetcher():
|
64 |
+
"""CPU prefetcher.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
loader: Dataloader.
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(self, loader):
|
71 |
+
self.ori_loader = loader
|
72 |
+
self.loader = iter(loader)
|
73 |
+
|
74 |
+
def next(self):
|
75 |
+
try:
|
76 |
+
return next(self.loader)
|
77 |
+
except StopIteration:
|
78 |
+
return None
|
79 |
+
|
80 |
+
def reset(self):
|
81 |
+
self.loader = iter(self.ori_loader)
|
82 |
+
|
83 |
+
|
84 |
+
class CUDAPrefetcher():
|
85 |
+
"""CUDA prefetcher.
|
86 |
+
|
87 |
+
Ref:
|
88 |
+
https://github.com/NVIDIA/apex/issues/304#
|
89 |
+
|
90 |
+
It may consums more GPU memory.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
loader: Dataloader.
|
94 |
+
opt (dict): Options.
|
95 |
+
"""
|
96 |
+
|
97 |
+
def __init__(self, loader, opt):
|
98 |
+
self.ori_loader = loader
|
99 |
+
self.loader = iter(loader)
|
100 |
+
self.opt = opt
|
101 |
+
self.stream = torch.cuda.Stream()
|
102 |
+
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
|
103 |
+
self.preload()
|
104 |
+
|
105 |
+
def preload(self):
|
106 |
+
try:
|
107 |
+
self.batch = next(self.loader) # self.batch is a dict
|
108 |
+
except StopIteration:
|
109 |
+
self.batch = None
|
110 |
+
return None
|
111 |
+
# put tensors to gpu
|
112 |
+
with torch.cuda.stream(self.stream):
|
113 |
+
for k, v in self.batch.items():
|
114 |
+
if torch.is_tensor(v):
|
115 |
+
self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
|
116 |
+
|
117 |
+
def next(self):
|
118 |
+
torch.cuda.current_stream().wait_stream(self.stream)
|
119 |
+
batch = self.batch
|
120 |
+
self.preload()
|
121 |
+
return batch
|
122 |
+
|
123 |
+
def reset(self):
|
124 |
+
self.loader = iter(self.ori_loader)
|
125 |
+
self.preload()
|
blissful_tuner/codeformer/basicsr/data/transforms.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import random
|
3 |
+
|
4 |
+
|
5 |
+
def mod_crop(img, scale):
|
6 |
+
"""Mod crop images, used during testing.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
img (ndarray): Input image.
|
10 |
+
scale (int): Scale factor.
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
ndarray: Result image.
|
14 |
+
"""
|
15 |
+
img = img.copy()
|
16 |
+
if img.ndim in (2, 3):
|
17 |
+
h, w = img.shape[0], img.shape[1]
|
18 |
+
h_remainder, w_remainder = h % scale, w % scale
|
19 |
+
img = img[:h - h_remainder, :w - w_remainder, ...]
|
20 |
+
else:
|
21 |
+
raise ValueError(f'Wrong img ndim: {img.ndim}.')
|
22 |
+
return img
|
23 |
+
|
24 |
+
|
25 |
+
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
|
26 |
+
"""Paired random crop.
|
27 |
+
|
28 |
+
It crops lists of lq and gt images with corresponding locations.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
img_gts (list[ndarray] | ndarray): GT images. Note that all images
|
32 |
+
should have the same shape. If the input is an ndarray, it will
|
33 |
+
be transformed to a list containing itself.
|
34 |
+
img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
|
35 |
+
should have the same shape. If the input is an ndarray, it will
|
36 |
+
be transformed to a list containing itself.
|
37 |
+
gt_patch_size (int): GT patch size.
|
38 |
+
scale (int): Scale factor.
|
39 |
+
gt_path (str): Path to ground-truth.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
list[ndarray] | ndarray: GT images and LQ images. If returned results
|
43 |
+
only have one element, just return ndarray.
|
44 |
+
"""
|
45 |
+
|
46 |
+
if not isinstance(img_gts, list):
|
47 |
+
img_gts = [img_gts]
|
48 |
+
if not isinstance(img_lqs, list):
|
49 |
+
img_lqs = [img_lqs]
|
50 |
+
|
51 |
+
h_lq, w_lq, _ = img_lqs[0].shape
|
52 |
+
h_gt, w_gt, _ = img_gts[0].shape
|
53 |
+
lq_patch_size = gt_patch_size // scale
|
54 |
+
|
55 |
+
if h_gt != h_lq * scale or w_gt != w_lq * scale:
|
56 |
+
raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
|
57 |
+
f'multiplication of LQ ({h_lq}, {w_lq}).')
|
58 |
+
if h_lq < lq_patch_size or w_lq < lq_patch_size:
|
59 |
+
raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
|
60 |
+
f'({lq_patch_size}, {lq_patch_size}). '
|
61 |
+
f'Please remove {gt_path}.')
|
62 |
+
|
63 |
+
# randomly choose top and left coordinates for lq patch
|
64 |
+
top = random.randint(0, h_lq - lq_patch_size)
|
65 |
+
left = random.randint(0, w_lq - lq_patch_size)
|
66 |
+
|
67 |
+
# crop lq patch
|
68 |
+
img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
|
69 |
+
|
70 |
+
# crop corresponding gt patch
|
71 |
+
top_gt, left_gt = int(top * scale), int(left * scale)
|
72 |
+
img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
|
73 |
+
if len(img_gts) == 1:
|
74 |
+
img_gts = img_gts[0]
|
75 |
+
if len(img_lqs) == 1:
|
76 |
+
img_lqs = img_lqs[0]
|
77 |
+
return img_gts, img_lqs
|
78 |
+
|
79 |
+
|
80 |
+
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
|
81 |
+
"""Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
|
82 |
+
|
83 |
+
We use vertical flip and transpose for rotation implementation.
|
84 |
+
All the images in the list use the same augmentation.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
imgs (list[ndarray] | ndarray): Images to be augmented. If the input
|
88 |
+
is an ndarray, it will be transformed to a list.
|
89 |
+
hflip (bool): Horizontal flip. Default: True.
|
90 |
+
rotation (bool): Ratotation. Default: True.
|
91 |
+
flows (list[ndarray]: Flows to be augmented. If the input is an
|
92 |
+
ndarray, it will be transformed to a list.
|
93 |
+
Dimension is (h, w, 2). Default: None.
|
94 |
+
return_status (bool): Return the status of flip and rotation.
|
95 |
+
Default: False.
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
list[ndarray] | ndarray: Augmented images and flows. If returned
|
99 |
+
results only have one element, just return ndarray.
|
100 |
+
|
101 |
+
"""
|
102 |
+
hflip = hflip and random.random() < 0.5
|
103 |
+
vflip = rotation and random.random() < 0.5
|
104 |
+
rot90 = rotation and random.random() < 0.5
|
105 |
+
|
106 |
+
def _augment(img):
|
107 |
+
if hflip: # horizontal
|
108 |
+
cv2.flip(img, 1, img)
|
109 |
+
if vflip: # vertical
|
110 |
+
cv2.flip(img, 0, img)
|
111 |
+
if rot90:
|
112 |
+
img = img.transpose(1, 0, 2)
|
113 |
+
return img
|
114 |
+
|
115 |
+
def _augment_flow(flow):
|
116 |
+
if hflip: # horizontal
|
117 |
+
cv2.flip(flow, 1, flow)
|
118 |
+
flow[:, :, 0] *= -1
|
119 |
+
if vflip: # vertical
|
120 |
+
cv2.flip(flow, 0, flow)
|
121 |
+
flow[:, :, 1] *= -1
|
122 |
+
if rot90:
|
123 |
+
flow = flow.transpose(1, 0, 2)
|
124 |
+
flow = flow[:, :, [1, 0]]
|
125 |
+
return flow
|
126 |
+
|
127 |
+
if not isinstance(imgs, list):
|
128 |
+
imgs = [imgs]
|
129 |
+
imgs = [_augment(img) for img in imgs]
|
130 |
+
if len(imgs) == 1:
|
131 |
+
imgs = imgs[0]
|
132 |
+
|
133 |
+
if flows is not None:
|
134 |
+
if not isinstance(flows, list):
|
135 |
+
flows = [flows]
|
136 |
+
flows = [_augment_flow(flow) for flow in flows]
|
137 |
+
if len(flows) == 1:
|
138 |
+
flows = flows[0]
|
139 |
+
return imgs, flows
|
140 |
+
else:
|
141 |
+
if return_status:
|
142 |
+
return imgs, (hflip, vflip, rot90)
|
143 |
+
else:
|
144 |
+
return imgs
|
145 |
+
|
146 |
+
|
147 |
+
def img_rotate(img, angle, center=None, scale=1.0):
|
148 |
+
"""Rotate image.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
img (ndarray): Image to be rotated.
|
152 |
+
angle (float): Rotation angle in degrees. Positive values mean
|
153 |
+
counter-clockwise rotation.
|
154 |
+
center (tuple[int]): Rotation center. If the center is None,
|
155 |
+
initialize it as the center of the image. Default: None.
|
156 |
+
scale (float): Isotropic scale factor. Default: 1.0.
|
157 |
+
"""
|
158 |
+
(h, w) = img.shape[:2]
|
159 |
+
|
160 |
+
if center is None:
|
161 |
+
center = (w // 2, h // 2)
|
162 |
+
|
163 |
+
matrix = cv2.getRotationMatrix2D(center, angle, scale)
|
164 |
+
rotated_img = cv2.warpAffine(img, matrix, (w, h))
|
165 |
+
return rotated_img
|
blissful_tuner/codeformer/basicsr/losses/__init__.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
|
3 |
+
from codeformer.basicsr.utils import get_root_logger
|
4 |
+
from codeformer.basicsr.utils.registry import LOSS_REGISTRY
|
5 |
+
from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
|
6 |
+
gradient_penalty_loss, r1_penalty)
|
7 |
+
|
8 |
+
__all__ = [
|
9 |
+
'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
|
10 |
+
'r1_penalty', 'g_path_regularize'
|
11 |
+
]
|
12 |
+
|
13 |
+
|
14 |
+
def build_loss(opt):
|
15 |
+
"""Build loss from options.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
opt (dict): Configuration. It must constain:
|
19 |
+
type (str): Model type.
|
20 |
+
"""
|
21 |
+
opt = deepcopy(opt)
|
22 |
+
loss_type = opt.pop('type')
|
23 |
+
loss = LOSS_REGISTRY.get(loss_type)(**opt)
|
24 |
+
logger = get_root_logger()
|
25 |
+
logger.info(f'Loss [{loss.__class__.__name__}] is created.')
|
26 |
+
return loss
|
blissful_tuner/codeformer/basicsr/losses/loss_util.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
from torch.nn import functional as F
|
3 |
+
|
4 |
+
|
5 |
+
def reduce_loss(loss, reduction):
|
6 |
+
"""Reduce loss as specified.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
loss (Tensor): Elementwise loss tensor.
|
10 |
+
reduction (str): Options are 'none', 'mean' and 'sum'.
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
Tensor: Reduced loss tensor.
|
14 |
+
"""
|
15 |
+
reduction_enum = F._Reduction.get_enum(reduction)
|
16 |
+
# none: 0, elementwise_mean:1, sum: 2
|
17 |
+
if reduction_enum == 0:
|
18 |
+
return loss
|
19 |
+
elif reduction_enum == 1:
|
20 |
+
return loss.mean()
|
21 |
+
else:
|
22 |
+
return loss.sum()
|
23 |
+
|
24 |
+
|
25 |
+
def weight_reduce_loss(loss, weight=None, reduction='mean'):
|
26 |
+
"""Apply element-wise weight and reduce loss.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
loss (Tensor): Element-wise loss.
|
30 |
+
weight (Tensor): Element-wise weights. Default: None.
|
31 |
+
reduction (str): Same as built-in losses of PyTorch. Options are
|
32 |
+
'none', 'mean' and 'sum'. Default: 'mean'.
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
Tensor: Loss values.
|
36 |
+
"""
|
37 |
+
# if weight is specified, apply element-wise weight
|
38 |
+
if weight is not None:
|
39 |
+
assert weight.dim() == loss.dim()
|
40 |
+
assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
|
41 |
+
loss = loss * weight
|
42 |
+
|
43 |
+
# if weight is not specified or reduction is sum, just reduce the loss
|
44 |
+
if weight is None or reduction == 'sum':
|
45 |
+
loss = reduce_loss(loss, reduction)
|
46 |
+
# if reduction is mean, then compute mean over weight region
|
47 |
+
elif reduction == 'mean':
|
48 |
+
if weight.size(1) > 1:
|
49 |
+
weight = weight.sum()
|
50 |
+
else:
|
51 |
+
weight = weight.sum() * loss.size(1)
|
52 |
+
loss = loss.sum() / weight
|
53 |
+
|
54 |
+
return loss
|
55 |
+
|
56 |
+
|
57 |
+
def weighted_loss(loss_func):
|
58 |
+
"""Create a weighted version of a given loss function.
|
59 |
+
|
60 |
+
To use this decorator, the loss function must have the signature like
|
61 |
+
`loss_func(pred, target, **kwargs)`. The function only needs to compute
|
62 |
+
element-wise loss without any reduction. This decorator will add weight
|
63 |
+
and reduction arguments to the function. The decorated function will have
|
64 |
+
the signature like `loss_func(pred, target, weight=None, reduction='mean',
|
65 |
+
**kwargs)`.
|
66 |
+
|
67 |
+
:Example:
|
68 |
+
|
69 |
+
>>> import torch
|
70 |
+
>>> @weighted_loss
|
71 |
+
>>> def l1_loss(pred, target):
|
72 |
+
>>> return (pred - target).abs()
|
73 |
+
|
74 |
+
>>> pred = torch.Tensor([0, 2, 3])
|
75 |
+
>>> target = torch.Tensor([1, 1, 1])
|
76 |
+
>>> weight = torch.Tensor([1, 0, 1])
|
77 |
+
|
78 |
+
>>> l1_loss(pred, target)
|
79 |
+
tensor(1.3333)
|
80 |
+
>>> l1_loss(pred, target, weight)
|
81 |
+
tensor(1.5000)
|
82 |
+
>>> l1_loss(pred, target, reduction='none')
|
83 |
+
tensor([1., 1., 2.])
|
84 |
+
>>> l1_loss(pred, target, weight, reduction='sum')
|
85 |
+
tensor(3.)
|
86 |
+
"""
|
87 |
+
|
88 |
+
@functools.wraps(loss_func)
|
89 |
+
def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
|
90 |
+
# get element-wise loss
|
91 |
+
loss = loss_func(pred, target, **kwargs)
|
92 |
+
loss = weight_reduce_loss(loss, weight, reduction)
|
93 |
+
return loss
|
94 |
+
|
95 |
+
return wrapper
|
blissful_tuner/codeformer/basicsr/losses/losses.py
ADDED
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import lpips
|
3 |
+
import torch
|
4 |
+
from torch import autograd as autograd
|
5 |
+
from torch import nn as nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
from codeformer.basicsr.archs.vgg_arch import VGGFeatureExtractor
|
9 |
+
from codeformer.basicsr.utils.registry import LOSS_REGISTRY
|
10 |
+
from .loss_util import weighted_loss
|
11 |
+
|
12 |
+
_reduction_modes = ['none', 'mean', 'sum']
|
13 |
+
|
14 |
+
|
15 |
+
@weighted_loss
|
16 |
+
def l1_loss(pred, target):
|
17 |
+
return F.l1_loss(pred, target, reduction='none')
|
18 |
+
|
19 |
+
|
20 |
+
@weighted_loss
|
21 |
+
def mse_loss(pred, target):
|
22 |
+
return F.mse_loss(pred, target, reduction='none')
|
23 |
+
|
24 |
+
|
25 |
+
@weighted_loss
|
26 |
+
def charbonnier_loss(pred, target, eps=1e-12):
|
27 |
+
return torch.sqrt((pred - target)**2 + eps)
|
28 |
+
|
29 |
+
|
30 |
+
@LOSS_REGISTRY.register()
|
31 |
+
class L1Loss(nn.Module):
|
32 |
+
"""L1 (mean absolute error, MAE) loss.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
|
36 |
+
reduction (str): Specifies the reduction to apply to the output.
|
37 |
+
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self, loss_weight=1.0, reduction='mean'):
|
41 |
+
super(L1Loss, self).__init__()
|
42 |
+
if reduction not in ['none', 'mean', 'sum']:
|
43 |
+
raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
|
44 |
+
|
45 |
+
self.loss_weight = loss_weight
|
46 |
+
self.reduction = reduction
|
47 |
+
|
48 |
+
def forward(self, pred, target, weight=None, **kwargs):
|
49 |
+
"""
|
50 |
+
Args:
|
51 |
+
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
|
52 |
+
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
|
53 |
+
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
|
54 |
+
weights. Default: None.
|
55 |
+
"""
|
56 |
+
return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
|
57 |
+
|
58 |
+
|
59 |
+
@LOSS_REGISTRY.register()
|
60 |
+
class MSELoss(nn.Module):
|
61 |
+
"""MSE (L2) loss.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
loss_weight (float): Loss weight for MSE loss. Default: 1.0.
|
65 |
+
reduction (str): Specifies the reduction to apply to the output.
|
66 |
+
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
|
67 |
+
"""
|
68 |
+
|
69 |
+
def __init__(self, loss_weight=1.0, reduction='mean'):
|
70 |
+
super(MSELoss, self).__init__()
|
71 |
+
if reduction not in ['none', 'mean', 'sum']:
|
72 |
+
raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
|
73 |
+
|
74 |
+
self.loss_weight = loss_weight
|
75 |
+
self.reduction = reduction
|
76 |
+
|
77 |
+
def forward(self, pred, target, weight=None, **kwargs):
|
78 |
+
"""
|
79 |
+
Args:
|
80 |
+
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
|
81 |
+
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
|
82 |
+
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
|
83 |
+
weights. Default: None.
|
84 |
+
"""
|
85 |
+
return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
|
86 |
+
|
87 |
+
|
88 |
+
@LOSS_REGISTRY.register()
|
89 |
+
class CharbonnierLoss(nn.Module):
|
90 |
+
"""Charbonnier loss (one variant of Robust L1Loss, a differentiable
|
91 |
+
variant of L1Loss).
|
92 |
+
|
93 |
+
Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
|
94 |
+
Super-Resolution".
|
95 |
+
|
96 |
+
Args:
|
97 |
+
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
|
98 |
+
reduction (str): Specifies the reduction to apply to the output.
|
99 |
+
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
|
100 |
+
eps (float): A value used to control the curvature near zero.
|
101 |
+
Default: 1e-12.
|
102 |
+
"""
|
103 |
+
|
104 |
+
def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
|
105 |
+
super(CharbonnierLoss, self).__init__()
|
106 |
+
if reduction not in ['none', 'mean', 'sum']:
|
107 |
+
raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
|
108 |
+
|
109 |
+
self.loss_weight = loss_weight
|
110 |
+
self.reduction = reduction
|
111 |
+
self.eps = eps
|
112 |
+
|
113 |
+
def forward(self, pred, target, weight=None, **kwargs):
|
114 |
+
"""
|
115 |
+
Args:
|
116 |
+
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
|
117 |
+
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
|
118 |
+
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
|
119 |
+
weights. Default: None.
|
120 |
+
"""
|
121 |
+
return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
|
122 |
+
|
123 |
+
|
124 |
+
@LOSS_REGISTRY.register()
|
125 |
+
class WeightedTVLoss(L1Loss):
|
126 |
+
"""Weighted TV loss.
|
127 |
+
|
128 |
+
Args:
|
129 |
+
loss_weight (float): Loss weight. Default: 1.0.
|
130 |
+
"""
|
131 |
+
|
132 |
+
def __init__(self, loss_weight=1.0):
|
133 |
+
super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)
|
134 |
+
|
135 |
+
def forward(self, pred, weight=None):
|
136 |
+
y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=weight[:, :, :-1, :])
|
137 |
+
x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=weight[:, :, :, :-1])
|
138 |
+
|
139 |
+
loss = x_diff + y_diff
|
140 |
+
|
141 |
+
return loss
|
142 |
+
|
143 |
+
|
144 |
+
@LOSS_REGISTRY.register()
|
145 |
+
class PerceptualLoss(nn.Module):
|
146 |
+
"""Perceptual loss with commonly used style loss.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
layer_weights (dict): The weight for each layer of vgg feature.
|
150 |
+
Here is an example: {'conv5_4': 1.}, which means the conv5_4
|
151 |
+
feature layer (before relu5_4) will be extracted with weight
|
152 |
+
1.0 in calculting losses.
|
153 |
+
vgg_type (str): The type of vgg network used as feature extractor.
|
154 |
+
Default: 'vgg19'.
|
155 |
+
use_input_norm (bool): If True, normalize the input image in vgg.
|
156 |
+
Default: True.
|
157 |
+
range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
|
158 |
+
Default: False.
|
159 |
+
perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
|
160 |
+
loss will be calculated and the loss will multiplied by the
|
161 |
+
weight. Default: 1.0.
|
162 |
+
style_weight (float): If `style_weight > 0`, the style loss will be
|
163 |
+
calculated and the loss will multiplied by the weight.
|
164 |
+
Default: 0.
|
165 |
+
criterion (str): Criterion used for perceptual loss. Default: 'l1'.
|
166 |
+
"""
|
167 |
+
|
168 |
+
def __init__(self,
|
169 |
+
layer_weights,
|
170 |
+
vgg_type='vgg19',
|
171 |
+
use_input_norm=True,
|
172 |
+
range_norm=False,
|
173 |
+
perceptual_weight=1.0,
|
174 |
+
style_weight=0.,
|
175 |
+
criterion='l1'):
|
176 |
+
super(PerceptualLoss, self).__init__()
|
177 |
+
self.perceptual_weight = perceptual_weight
|
178 |
+
self.style_weight = style_weight
|
179 |
+
self.layer_weights = layer_weights
|
180 |
+
self.vgg = VGGFeatureExtractor(
|
181 |
+
layer_name_list=list(layer_weights.keys()),
|
182 |
+
vgg_type=vgg_type,
|
183 |
+
use_input_norm=use_input_norm,
|
184 |
+
range_norm=range_norm)
|
185 |
+
|
186 |
+
self.criterion_type = criterion
|
187 |
+
if self.criterion_type == 'l1':
|
188 |
+
self.criterion = torch.nn.L1Loss()
|
189 |
+
elif self.criterion_type == 'l2':
|
190 |
+
self.criterion = torch.nn.L2loss()
|
191 |
+
elif self.criterion_type == 'mse':
|
192 |
+
self.criterion = torch.nn.MSELoss(reduction='mean')
|
193 |
+
elif self.criterion_type == 'fro':
|
194 |
+
self.criterion = None
|
195 |
+
else:
|
196 |
+
raise NotImplementedError(f'{criterion} criterion has not been supported.')
|
197 |
+
|
198 |
+
def forward(self, x, gt):
|
199 |
+
"""Forward function.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
x (Tensor): Input tensor with shape (n, c, h, w).
|
203 |
+
gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
|
204 |
+
|
205 |
+
Returns:
|
206 |
+
Tensor: Forward results.
|
207 |
+
"""
|
208 |
+
# extract vgg features
|
209 |
+
x_features = self.vgg(x)
|
210 |
+
gt_features = self.vgg(gt.detach())
|
211 |
+
|
212 |
+
# calculate perceptual loss
|
213 |
+
if self.perceptual_weight > 0:
|
214 |
+
percep_loss = 0
|
215 |
+
for k in x_features.keys():
|
216 |
+
if self.criterion_type == 'fro':
|
217 |
+
percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
|
218 |
+
else:
|
219 |
+
percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
|
220 |
+
percep_loss *= self.perceptual_weight
|
221 |
+
else:
|
222 |
+
percep_loss = None
|
223 |
+
|
224 |
+
# calculate style loss
|
225 |
+
if self.style_weight > 0:
|
226 |
+
style_loss = 0
|
227 |
+
for k in x_features.keys():
|
228 |
+
if self.criterion_type == 'fro':
|
229 |
+
style_loss += torch.norm(
|
230 |
+
self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
|
231 |
+
else:
|
232 |
+
style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
|
233 |
+
gt_features[k])) * self.layer_weights[k]
|
234 |
+
style_loss *= self.style_weight
|
235 |
+
else:
|
236 |
+
style_loss = None
|
237 |
+
|
238 |
+
return percep_loss, style_loss
|
239 |
+
|
240 |
+
def _gram_mat(self, x):
|
241 |
+
"""Calculate Gram matrix.
|
242 |
+
|
243 |
+
Args:
|
244 |
+
x (torch.Tensor): Tensor with shape of (n, c, h, w).
|
245 |
+
|
246 |
+
Returns:
|
247 |
+
torch.Tensor: Gram matrix.
|
248 |
+
"""
|
249 |
+
n, c, h, w = x.size()
|
250 |
+
features = x.view(n, c, w * h)
|
251 |
+
features_t = features.transpose(1, 2)
|
252 |
+
gram = features.bmm(features_t) / (c * h * w)
|
253 |
+
return gram
|
254 |
+
|
255 |
+
|
256 |
+
@LOSS_REGISTRY.register()
|
257 |
+
class LPIPSLoss(nn.Module):
|
258 |
+
def __init__(self,
|
259 |
+
loss_weight=1.0,
|
260 |
+
use_input_norm=True,
|
261 |
+
range_norm=False,):
|
262 |
+
super(LPIPSLoss, self).__init__()
|
263 |
+
self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
|
264 |
+
self.loss_weight = loss_weight
|
265 |
+
self.use_input_norm = use_input_norm
|
266 |
+
self.range_norm = range_norm
|
267 |
+
|
268 |
+
if self.use_input_norm:
|
269 |
+
# the mean is for image with range [0, 1]
|
270 |
+
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
271 |
+
# the std is for image with range [0, 1]
|
272 |
+
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
273 |
+
|
274 |
+
def forward(self, pred, target):
|
275 |
+
if self.range_norm:
|
276 |
+
pred = (pred + 1) / 2
|
277 |
+
target = (target + 1) / 2
|
278 |
+
if self.use_input_norm:
|
279 |
+
pred = (pred - self.mean) / self.std
|
280 |
+
target = (target - self.mean) / self.std
|
281 |
+
lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
|
282 |
+
return self.loss_weight * lpips_loss.mean()
|
283 |
+
|
284 |
+
|
285 |
+
@LOSS_REGISTRY.register()
|
286 |
+
class GANLoss(nn.Module):
|
287 |
+
"""Define GAN loss.
|
288 |
+
|
289 |
+
Args:
|
290 |
+
gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
|
291 |
+
real_label_val (float): The value for real label. Default: 1.0.
|
292 |
+
fake_label_val (float): The value for fake label. Default: 0.0.
|
293 |
+
loss_weight (float): Loss weight. Default: 1.0.
|
294 |
+
Note that loss_weight is only for generators; and it is always 1.0
|
295 |
+
for discriminators.
|
296 |
+
"""
|
297 |
+
|
298 |
+
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
|
299 |
+
super(GANLoss, self).__init__()
|
300 |
+
self.gan_type = gan_type
|
301 |
+
self.loss_weight = loss_weight
|
302 |
+
self.real_label_val = real_label_val
|
303 |
+
self.fake_label_val = fake_label_val
|
304 |
+
|
305 |
+
if self.gan_type == 'vanilla':
|
306 |
+
self.loss = nn.BCEWithLogitsLoss()
|
307 |
+
elif self.gan_type == 'lsgan':
|
308 |
+
self.loss = nn.MSELoss()
|
309 |
+
elif self.gan_type == 'wgan':
|
310 |
+
self.loss = self._wgan_loss
|
311 |
+
elif self.gan_type == 'wgan_softplus':
|
312 |
+
self.loss = self._wgan_softplus_loss
|
313 |
+
elif self.gan_type == 'hinge':
|
314 |
+
self.loss = nn.ReLU()
|
315 |
+
else:
|
316 |
+
raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
|
317 |
+
|
318 |
+
def _wgan_loss(self, input, target):
|
319 |
+
"""wgan loss.
|
320 |
+
|
321 |
+
Args:
|
322 |
+
input (Tensor): Input tensor.
|
323 |
+
target (bool): Target label.
|
324 |
+
|
325 |
+
Returns:
|
326 |
+
Tensor: wgan loss.
|
327 |
+
"""
|
328 |
+
return -input.mean() if target else input.mean()
|
329 |
+
|
330 |
+
def _wgan_softplus_loss(self, input, target):
|
331 |
+
"""wgan loss with soft plus. softplus is a smooth approximation to the
|
332 |
+
ReLU function.
|
333 |
+
|
334 |
+
In StyleGAN2, it is called:
|
335 |
+
Logistic loss for discriminator;
|
336 |
+
Non-saturating loss for generator.
|
337 |
+
|
338 |
+
Args:
|
339 |
+
input (Tensor): Input tensor.
|
340 |
+
target (bool): Target label.
|
341 |
+
|
342 |
+
Returns:
|
343 |
+
Tensor: wgan loss.
|
344 |
+
"""
|
345 |
+
return F.softplus(-input).mean() if target else F.softplus(input).mean()
|
346 |
+
|
347 |
+
def get_target_label(self, input, target_is_real):
|
348 |
+
"""Get target label.
|
349 |
+
|
350 |
+
Args:
|
351 |
+
input (Tensor): Input tensor.
|
352 |
+
target_is_real (bool): Whether the target is real or fake.
|
353 |
+
|
354 |
+
Returns:
|
355 |
+
(bool | Tensor): Target tensor. Return bool for wgan, otherwise,
|
356 |
+
return Tensor.
|
357 |
+
"""
|
358 |
+
|
359 |
+
if self.gan_type in ['wgan', 'wgan_softplus']:
|
360 |
+
return target_is_real
|
361 |
+
target_val = (self.real_label_val if target_is_real else self.fake_label_val)
|
362 |
+
return input.new_ones(input.size()) * target_val
|
363 |
+
|
364 |
+
def forward(self, input, target_is_real, is_disc=False):
|
365 |
+
"""
|
366 |
+
Args:
|
367 |
+
input (Tensor): The input for the loss module, i.e., the network
|
368 |
+
prediction.
|
369 |
+
target_is_real (bool): Whether the targe is real or fake.
|
370 |
+
is_disc (bool): Whether the loss for discriminators or not.
|
371 |
+
Default: False.
|
372 |
+
|
373 |
+
Returns:
|
374 |
+
Tensor: GAN loss value.
|
375 |
+
"""
|
376 |
+
if self.gan_type == 'hinge':
|
377 |
+
if is_disc: # for discriminators in hinge-gan
|
378 |
+
input = -input if target_is_real else input
|
379 |
+
loss = self.loss(1 + input).mean()
|
380 |
+
else: # for generators in hinge-gan
|
381 |
+
loss = -input.mean()
|
382 |
+
else: # other gan types
|
383 |
+
target_label = self.get_target_label(input, target_is_real)
|
384 |
+
loss = self.loss(input, target_label)
|
385 |
+
|
386 |
+
# loss_weight is always 1.0 for discriminators
|
387 |
+
return loss if is_disc else loss * self.loss_weight
|
388 |
+
|
389 |
+
|
390 |
+
def r1_penalty(real_pred, real_img):
|
391 |
+
"""R1 regularization for discriminator. The core idea is to
|
392 |
+
penalize the gradient on real data alone: when the
|
393 |
+
generator distribution produces the true data distribution
|
394 |
+
and the discriminator is equal to 0 on the data manifold, the
|
395 |
+
gradient penalty ensures that the discriminator cannot create
|
396 |
+
a non-zero gradient orthogonal to the data manifold without
|
397 |
+
suffering a loss in the GAN game.
|
398 |
+
|
399 |
+
Ref:
|
400 |
+
Eq. 9 in Which training methods for GANs do actually converge.
|
401 |
+
"""
|
402 |
+
grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
|
403 |
+
grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
|
404 |
+
return grad_penalty
|
405 |
+
|
406 |
+
|
407 |
+
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
|
408 |
+
noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
|
409 |
+
grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
|
410 |
+
path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
|
411 |
+
|
412 |
+
path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
|
413 |
+
|
414 |
+
path_penalty = (path_lengths - path_mean).pow(2).mean()
|
415 |
+
|
416 |
+
return path_penalty, path_lengths.detach().mean(), path_mean.detach()
|
417 |
+
|
418 |
+
|
419 |
+
def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
|
420 |
+
"""Calculate gradient penalty for wgan-gp.
|
421 |
+
|
422 |
+
Args:
|
423 |
+
discriminator (nn.Module): Network for the discriminator.
|
424 |
+
real_data (Tensor): Real input data.
|
425 |
+
fake_data (Tensor): Fake input data.
|
426 |
+
weight (Tensor): Weight tensor. Default: None.
|
427 |
+
|
428 |
+
Returns:
|
429 |
+
Tensor: A tensor for gradient penalty.
|
430 |
+
"""
|
431 |
+
|
432 |
+
batch_size = real_data.size(0)
|
433 |
+
alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
|
434 |
+
|
435 |
+
# interpolate between real_data and fake_data
|
436 |
+
interpolates = alpha * real_data + (1. - alpha) * fake_data
|
437 |
+
interpolates = autograd.Variable(interpolates, requires_grad=True)
|
438 |
+
|
439 |
+
disc_interpolates = discriminator(interpolates)
|
440 |
+
gradients = autograd.grad(
|
441 |
+
outputs=disc_interpolates,
|
442 |
+
inputs=interpolates,
|
443 |
+
grad_outputs=torch.ones_like(disc_interpolates),
|
444 |
+
create_graph=True,
|
445 |
+
retain_graph=True,
|
446 |
+
only_inputs=True)[0]
|
447 |
+
|
448 |
+
if weight is not None:
|
449 |
+
gradients = gradients * weight
|
450 |
+
|
451 |
+
gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
|
452 |
+
if weight is not None:
|
453 |
+
gradients_penalty /= torch.mean(weight)
|
454 |
+
|
455 |
+
return gradients_penalty
|
blissful_tuner/codeformer/basicsr/metrics/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
|
3 |
+
from codeformer.basicsr.utils.registry import METRIC_REGISTRY
|
4 |
+
from .psnr_ssim import calculate_psnr, calculate_ssim
|
5 |
+
|
6 |
+
__all__ = ['calculate_psnr', 'calculate_ssim']
|
7 |
+
|
8 |
+
|
9 |
+
def calculate_metric(data, opt):
|
10 |
+
"""Calculate metric from data and options.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
opt (dict): Configuration. It must constain:
|
14 |
+
type (str): Model type.
|
15 |
+
"""
|
16 |
+
opt = deepcopy(opt)
|
17 |
+
metric_type = opt.pop('type')
|
18 |
+
metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
|
19 |
+
return metric
|
blissful_tuner/codeformer/basicsr/metrics/metric_util.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
from codeformer.basicsr.utils.matlab_functions import bgr2ycbcr
|
4 |
+
|
5 |
+
|
6 |
+
def reorder_image(img, input_order='HWC'):
|
7 |
+
"""Reorder images to 'HWC' order.
|
8 |
+
|
9 |
+
If the input_order is (h, w), return (h, w, 1);
|
10 |
+
If the input_order is (c, h, w), return (h, w, c);
|
11 |
+
If the input_order is (h, w, c), return as it is.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
img (ndarray): Input image.
|
15 |
+
input_order (str): Whether the input order is 'HWC' or 'CHW'.
|
16 |
+
If the input image shape is (h, w), input_order will not have
|
17 |
+
effects. Default: 'HWC'.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
ndarray: reordered image.
|
21 |
+
"""
|
22 |
+
|
23 |
+
if input_order not in ['HWC', 'CHW']:
|
24 |
+
raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
|
25 |
+
if len(img.shape) == 2:
|
26 |
+
img = img[..., None]
|
27 |
+
if input_order == 'CHW':
|
28 |
+
img = img.transpose(1, 2, 0)
|
29 |
+
return img
|
30 |
+
|
31 |
+
|
32 |
+
def to_y_channel(img):
|
33 |
+
"""Change to Y channel of YCbCr.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
img (ndarray): Images with range [0, 255].
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
(ndarray): Images with range [0, 255] (float type) without round.
|
40 |
+
"""
|
41 |
+
img = img.astype(np.float32) / 255.
|
42 |
+
if img.ndim == 3 and img.shape[2] == 3:
|
43 |
+
img = bgr2ycbcr(img, y_only=True)
|
44 |
+
img = img[..., None]
|
45 |
+
return img * 255.
|
blissful_tuner/codeformer/basicsr/metrics/psnr_ssim.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from codeformer.basicsr.metrics.metric_util import reorder_image, to_y_channel
|
5 |
+
from codeformer.basicsr.utils.registry import METRIC_REGISTRY
|
6 |
+
|
7 |
+
|
8 |
+
@METRIC_REGISTRY.register()
|
9 |
+
def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
|
10 |
+
"""Calculate PSNR (Peak Signal-to-Noise Ratio).
|
11 |
+
|
12 |
+
Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
|
13 |
+
|
14 |
+
Args:
|
15 |
+
img1 (ndarray): Images with range [0, 255].
|
16 |
+
img2 (ndarray): Images with range [0, 255].
|
17 |
+
crop_border (int): Cropped pixels in each edge of an image. These
|
18 |
+
pixels are not involved in the PSNR calculation.
|
19 |
+
input_order (str): Whether the input order is 'HWC' or 'CHW'.
|
20 |
+
Default: 'HWC'.
|
21 |
+
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
float: psnr result.
|
25 |
+
"""
|
26 |
+
|
27 |
+
assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
|
28 |
+
if input_order not in ['HWC', 'CHW']:
|
29 |
+
raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
|
30 |
+
img1 = reorder_image(img1, input_order=input_order)
|
31 |
+
img2 = reorder_image(img2, input_order=input_order)
|
32 |
+
img1 = img1.astype(np.float64)
|
33 |
+
img2 = img2.astype(np.float64)
|
34 |
+
|
35 |
+
if crop_border != 0:
|
36 |
+
img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
|
37 |
+
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
|
38 |
+
|
39 |
+
if test_y_channel:
|
40 |
+
img1 = to_y_channel(img1)
|
41 |
+
img2 = to_y_channel(img2)
|
42 |
+
|
43 |
+
mse = np.mean((img1 - img2)**2)
|
44 |
+
if mse == 0:
|
45 |
+
return float('inf')
|
46 |
+
return 20. * np.log10(255. / np.sqrt(mse))
|
47 |
+
|
48 |
+
|
49 |
+
def _ssim(img1, img2):
|
50 |
+
"""Calculate SSIM (structural similarity) for one channel images.
|
51 |
+
|
52 |
+
It is called by func:`calculate_ssim`.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
img1 (ndarray): Images with range [0, 255] with order 'HWC'.
|
56 |
+
img2 (ndarray): Images with range [0, 255] with order 'HWC'.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
float: ssim result.
|
60 |
+
"""
|
61 |
+
|
62 |
+
C1 = (0.01 * 255)**2
|
63 |
+
C2 = (0.03 * 255)**2
|
64 |
+
|
65 |
+
img1 = img1.astype(np.float64)
|
66 |
+
img2 = img2.astype(np.float64)
|
67 |
+
kernel = cv2.getGaussianKernel(11, 1.5)
|
68 |
+
window = np.outer(kernel, kernel.transpose())
|
69 |
+
|
70 |
+
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
|
71 |
+
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
72 |
+
mu1_sq = mu1**2
|
73 |
+
mu2_sq = mu2**2
|
74 |
+
mu1_mu2 = mu1 * mu2
|
75 |
+
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
|
76 |
+
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
|
77 |
+
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
78 |
+
|
79 |
+
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
80 |
+
return ssim_map.mean()
|
81 |
+
|
82 |
+
|
83 |
+
@METRIC_REGISTRY.register()
|
84 |
+
def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
|
85 |
+
"""Calculate SSIM (structural similarity).
|
86 |
+
|
87 |
+
Ref:
|
88 |
+
Image quality assessment: From error visibility to structural similarity
|
89 |
+
|
90 |
+
The results are the same as that of the official released MATLAB code in
|
91 |
+
https://ece.uwaterloo.ca/~z70wang/research/ssim/.
|
92 |
+
|
93 |
+
For three-channel images, SSIM is calculated for each channel and then
|
94 |
+
averaged.
|
95 |
+
|
96 |
+
Args:
|
97 |
+
img1 (ndarray): Images with range [0, 255].
|
98 |
+
img2 (ndarray): Images with range [0, 255].
|
99 |
+
crop_border (int): Cropped pixels in each edge of an image. These
|
100 |
+
pixels are not involved in the SSIM calculation.
|
101 |
+
input_order (str): Whether the input order is 'HWC' or 'CHW'.
|
102 |
+
Default: 'HWC'.
|
103 |
+
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
float: ssim result.
|
107 |
+
"""
|
108 |
+
|
109 |
+
assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
|
110 |
+
if input_order not in ['HWC', 'CHW']:
|
111 |
+
raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
|
112 |
+
img1 = reorder_image(img1, input_order=input_order)
|
113 |
+
img2 = reorder_image(img2, input_order=input_order)
|
114 |
+
img1 = img1.astype(np.float64)
|
115 |
+
img2 = img2.astype(np.float64)
|
116 |
+
|
117 |
+
if crop_border != 0:
|
118 |
+
img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
|
119 |
+
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
|
120 |
+
|
121 |
+
if test_y_channel:
|
122 |
+
img1 = to_y_channel(img1)
|
123 |
+
img2 = to_y_channel(img2)
|
124 |
+
|
125 |
+
ssims = []
|
126 |
+
for i in range(img1.shape[2]):
|
127 |
+
ssims.append(_ssim(img1[..., i], img2[..., i]))
|
128 |
+
return np.array(ssims).mean()
|
blissful_tuner/codeformer/basicsr/models/__init__.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
from copy import deepcopy
|
3 |
+
from os import path as osp
|
4 |
+
|
5 |
+
from codeformer.basicsr.utils import get_root_logger, scandir
|
6 |
+
from codeformer.basicsr.utils.registry import MODEL_REGISTRY
|
7 |
+
|
8 |
+
__all__ = ['build_model']
|
9 |
+
|
10 |
+
# automatically scan and import model modules for registry
|
11 |
+
# scan all the files under the 'models' folder and collect files ending with
|
12 |
+
# '_model.py'
|
13 |
+
model_folder = osp.dirname(osp.abspath(__file__))
|
14 |
+
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
|
15 |
+
# import all the model modules
|
16 |
+
_model_modules = [importlib.import_module(f'codeformer.basicsr.models.{file_name}') for file_name in model_filenames]
|
17 |
+
|
18 |
+
|
19 |
+
def build_model(opt):
|
20 |
+
"""Build model from options.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
opt (dict): Configuration. It must constain:
|
24 |
+
model_type (str): Model type.
|
25 |
+
"""
|
26 |
+
opt = deepcopy(opt)
|
27 |
+
model = MODEL_REGISTRY.get(opt['model_type'])(opt)
|
28 |
+
logger = get_root_logger()
|
29 |
+
logger.info(f'Model [{model.__class__.__name__}] is created.')
|
30 |
+
return model
|
blissful_tuner/codeformer/basicsr/models/base_model.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from collections import OrderedDict
|
5 |
+
from copy import deepcopy
|
6 |
+
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
7 |
+
|
8 |
+
from codeformer.basicsr.models import lr_scheduler as lr_scheduler
|
9 |
+
from codeformer.basicsr.utils.dist_util import master_only
|
10 |
+
|
11 |
+
logger = logging.getLogger('basicsr')
|
12 |
+
|
13 |
+
|
14 |
+
class BaseModel():
|
15 |
+
"""Base model."""
|
16 |
+
|
17 |
+
def __init__(self, opt):
|
18 |
+
self.opt = opt
|
19 |
+
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
|
20 |
+
self.is_train = opt['is_train']
|
21 |
+
self.schedulers = []
|
22 |
+
self.optimizers = []
|
23 |
+
|
24 |
+
def feed_data(self, data):
|
25 |
+
pass
|
26 |
+
|
27 |
+
def optimize_parameters(self):
|
28 |
+
pass
|
29 |
+
|
30 |
+
def get_current_visuals(self):
|
31 |
+
pass
|
32 |
+
|
33 |
+
def save(self, epoch, current_iter):
|
34 |
+
"""Save networks and training state."""
|
35 |
+
pass
|
36 |
+
|
37 |
+
def validation(self, dataloader, current_iter, tb_logger, save_img=False):
|
38 |
+
"""Validation function.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
dataloader (torch.utils.data.DataLoader): Validation dataloader.
|
42 |
+
current_iter (int): Current iteration.
|
43 |
+
tb_logger (tensorboard logger): Tensorboard logger.
|
44 |
+
save_img (bool): Whether to save images. Default: False.
|
45 |
+
"""
|
46 |
+
if self.opt['dist']:
|
47 |
+
self.dist_validation(dataloader, current_iter, tb_logger, save_img)
|
48 |
+
else:
|
49 |
+
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
50 |
+
|
51 |
+
def model_ema(self, decay=0.999):
|
52 |
+
net_g = self.get_bare_model(self.net_g)
|
53 |
+
|
54 |
+
net_g_params = dict(net_g.named_parameters())
|
55 |
+
net_g_ema_params = dict(self.net_g_ema.named_parameters())
|
56 |
+
|
57 |
+
for k in net_g_ema_params.keys():
|
58 |
+
net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)
|
59 |
+
|
60 |
+
def get_current_log(self):
|
61 |
+
return self.log_dict
|
62 |
+
|
63 |
+
def model_to_device(self, net):
|
64 |
+
"""Model to device. It also warps models with DistributedDataParallel
|
65 |
+
or DataParallel.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
net (nn.Module)
|
69 |
+
"""
|
70 |
+
net = net.to(self.device)
|
71 |
+
if self.opt['dist']:
|
72 |
+
find_unused_parameters = self.opt.get('find_unused_parameters', False)
|
73 |
+
net = DistributedDataParallel(
|
74 |
+
net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
|
75 |
+
elif self.opt['num_gpu'] > 1:
|
76 |
+
net = DataParallel(net)
|
77 |
+
return net
|
78 |
+
|
79 |
+
def get_optimizer(self, optim_type, params, lr, **kwargs):
|
80 |
+
if optim_type == 'Adam':
|
81 |
+
optimizer = torch.optim.Adam(params, lr, **kwargs)
|
82 |
+
else:
|
83 |
+
raise NotImplementedError(f'optimizer {optim_type} is not supperted yet.')
|
84 |
+
return optimizer
|
85 |
+
|
86 |
+
def setup_schedulers(self):
|
87 |
+
"""Set up schedulers."""
|
88 |
+
train_opt = self.opt['train']
|
89 |
+
scheduler_type = train_opt['scheduler'].pop('type')
|
90 |
+
if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
|
91 |
+
for optimizer in self.optimizers:
|
92 |
+
self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
|
93 |
+
elif scheduler_type == 'CosineAnnealingRestartLR':
|
94 |
+
for optimizer in self.optimizers:
|
95 |
+
self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
|
96 |
+
else:
|
97 |
+
raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
|
98 |
+
|
99 |
+
def get_bare_model(self, net):
|
100 |
+
"""Get bare model, especially under wrapping with
|
101 |
+
DistributedDataParallel or DataParallel.
|
102 |
+
"""
|
103 |
+
if isinstance(net, (DataParallel, DistributedDataParallel)):
|
104 |
+
net = net.module
|
105 |
+
return net
|
106 |
+
|
107 |
+
@master_only
|
108 |
+
def print_network(self, net):
|
109 |
+
"""Print the str and parameter number of a network.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
net (nn.Module)
|
113 |
+
"""
|
114 |
+
if isinstance(net, (DataParallel, DistributedDataParallel)):
|
115 |
+
net_cls_str = (f'{net.__class__.__name__} - ' f'{net.module.__class__.__name__}')
|
116 |
+
else:
|
117 |
+
net_cls_str = f'{net.__class__.__name__}'
|
118 |
+
|
119 |
+
net = self.get_bare_model(net)
|
120 |
+
net_str = str(net)
|
121 |
+
net_params = sum(map(lambda x: x.numel(), net.parameters()))
|
122 |
+
|
123 |
+
logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
|
124 |
+
logger.info(net_str)
|
125 |
+
|
126 |
+
def _set_lr(self, lr_groups_l):
|
127 |
+
"""Set learning rate for warmup.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
lr_groups_l (list): List for lr_groups, each for an optimizer.
|
131 |
+
"""
|
132 |
+
for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
|
133 |
+
for param_group, lr in zip(optimizer.param_groups, lr_groups):
|
134 |
+
param_group['lr'] = lr
|
135 |
+
|
136 |
+
def _get_init_lr(self):
|
137 |
+
"""Get the initial lr, which is set by the scheduler.
|
138 |
+
"""
|
139 |
+
init_lr_groups_l = []
|
140 |
+
for optimizer in self.optimizers:
|
141 |
+
init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
|
142 |
+
return init_lr_groups_l
|
143 |
+
|
144 |
+
def update_learning_rate(self, current_iter, warmup_iter=-1):
|
145 |
+
"""Update learning rate.
|
146 |
+
|
147 |
+
Args:
|
148 |
+
current_iter (int): Current iteration.
|
149 |
+
warmup_iter (int): Warmup iter numbers. -1 for no warmup.
|
150 |
+
Default: -1.
|
151 |
+
"""
|
152 |
+
if current_iter > 1:
|
153 |
+
for scheduler in self.schedulers:
|
154 |
+
scheduler.step()
|
155 |
+
# set up warm-up learning rate
|
156 |
+
if current_iter < warmup_iter:
|
157 |
+
# get initial lr for each group
|
158 |
+
init_lr_g_l = self._get_init_lr()
|
159 |
+
# modify warming-up learning rates
|
160 |
+
# currently only support linearly warm up
|
161 |
+
warm_up_lr_l = []
|
162 |
+
for init_lr_g in init_lr_g_l:
|
163 |
+
warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
|
164 |
+
# set learning rate
|
165 |
+
self._set_lr(warm_up_lr_l)
|
166 |
+
|
167 |
+
def get_current_learning_rate(self):
|
168 |
+
return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
|
169 |
+
|
170 |
+
@master_only
|
171 |
+
def save_network(self, net, net_label, current_iter, param_key='params'):
|
172 |
+
"""Save networks.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
net (nn.Module | list[nn.Module]): Network(s) to be saved.
|
176 |
+
net_label (str): Network label.
|
177 |
+
current_iter (int): Current iter number.
|
178 |
+
param_key (str | list[str]): The parameter key(s) to save network.
|
179 |
+
Default: 'params'.
|
180 |
+
"""
|
181 |
+
if current_iter == -1:
|
182 |
+
current_iter = 'latest'
|
183 |
+
save_filename = f'{net_label}_{current_iter}.pth'
|
184 |
+
save_path = os.path.join(self.opt['path']['models'], save_filename)
|
185 |
+
|
186 |
+
net = net if isinstance(net, list) else [net]
|
187 |
+
param_key = param_key if isinstance(param_key, list) else [param_key]
|
188 |
+
assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'
|
189 |
+
|
190 |
+
save_dict = {}
|
191 |
+
for net_, param_key_ in zip(net, param_key):
|
192 |
+
net_ = self.get_bare_model(net_)
|
193 |
+
state_dict = net_.state_dict()
|
194 |
+
for key, param in state_dict.items():
|
195 |
+
if key.startswith('module.'): # remove unnecessary 'module.'
|
196 |
+
key = key[7:]
|
197 |
+
state_dict[key] = param.cpu()
|
198 |
+
save_dict[param_key_] = state_dict
|
199 |
+
|
200 |
+
torch.save(save_dict, save_path)
|
201 |
+
|
202 |
+
def _print_different_keys_loading(self, crt_net, load_net, strict=True):
|
203 |
+
"""Print keys with differnet name or different size when loading models.
|
204 |
+
|
205 |
+
1. Print keys with differnet names.
|
206 |
+
2. If strict=False, print the same key but with different tensor size.
|
207 |
+
It also ignore these keys with different sizes (not load).
|
208 |
+
|
209 |
+
Args:
|
210 |
+
crt_net (torch model): Current network.
|
211 |
+
load_net (dict): Loaded network.
|
212 |
+
strict (bool): Whether strictly loaded. Default: True.
|
213 |
+
"""
|
214 |
+
crt_net = self.get_bare_model(crt_net)
|
215 |
+
crt_net = crt_net.state_dict()
|
216 |
+
crt_net_keys = set(crt_net.keys())
|
217 |
+
load_net_keys = set(load_net.keys())
|
218 |
+
|
219 |
+
if crt_net_keys != load_net_keys:
|
220 |
+
logger.warning('Current net - loaded net:')
|
221 |
+
for v in sorted(list(crt_net_keys - load_net_keys)):
|
222 |
+
logger.warning(f' {v}')
|
223 |
+
logger.warning('Loaded net - current net:')
|
224 |
+
for v in sorted(list(load_net_keys - crt_net_keys)):
|
225 |
+
logger.warning(f' {v}')
|
226 |
+
|
227 |
+
# check the size for the same keys
|
228 |
+
if not strict:
|
229 |
+
common_keys = crt_net_keys & load_net_keys
|
230 |
+
for k in common_keys:
|
231 |
+
if crt_net[k].size() != load_net[k].size():
|
232 |
+
logger.warning(f'Size different, ignore [{k}]: crt_net: '
|
233 |
+
f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
|
234 |
+
load_net[k + '.ignore'] = load_net.pop(k)
|
235 |
+
|
236 |
+
def load_network(self, net, load_path, strict=True, param_key='params'):
|
237 |
+
"""Load network.
|
238 |
+
|
239 |
+
Args:
|
240 |
+
load_path (str): The path of networks to be loaded.
|
241 |
+
net (nn.Module): Network.
|
242 |
+
strict (bool): Whether strictly loaded.
|
243 |
+
param_key (str): The parameter key of loaded network. If set to
|
244 |
+
None, use the root 'path'.
|
245 |
+
Default: 'params'.
|
246 |
+
"""
|
247 |
+
net = self.get_bare_model(net)
|
248 |
+
logger.info(f'Loading {net.__class__.__name__} model from {load_path}.')
|
249 |
+
load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
|
250 |
+
if param_key is not None:
|
251 |
+
if param_key not in load_net and 'params' in load_net:
|
252 |
+
param_key = 'params'
|
253 |
+
logger.info('Loading: params_ema does not exist, use params.')
|
254 |
+
load_net = load_net[param_key]
|
255 |
+
# remove unnecessary 'module.'
|
256 |
+
for k, v in deepcopy(load_net).items():
|
257 |
+
if k.startswith('module.'):
|
258 |
+
load_net[k[7:]] = v
|
259 |
+
load_net.pop(k)
|
260 |
+
self._print_different_keys_loading(net, load_net, strict)
|
261 |
+
net.load_state_dict(load_net, strict=strict)
|
262 |
+
|
263 |
+
@master_only
|
264 |
+
def save_training_state(self, epoch, current_iter):
|
265 |
+
"""Save training states during training, which will be used for
|
266 |
+
resuming.
|
267 |
+
|
268 |
+
Args:
|
269 |
+
epoch (int): Current epoch.
|
270 |
+
current_iter (int): Current iteration.
|
271 |
+
"""
|
272 |
+
if current_iter != -1:
|
273 |
+
state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []}
|
274 |
+
for o in self.optimizers:
|
275 |
+
state['optimizers'].append(o.state_dict())
|
276 |
+
for s in self.schedulers:
|
277 |
+
state['schedulers'].append(s.state_dict())
|
278 |
+
save_filename = f'{current_iter}.state'
|
279 |
+
save_path = os.path.join(self.opt['path']['training_states'], save_filename)
|
280 |
+
torch.save(state, save_path)
|
281 |
+
|
282 |
+
def resume_training(self, resume_state):
|
283 |
+
"""Reload the optimizers and schedulers for resumed training.
|
284 |
+
|
285 |
+
Args:
|
286 |
+
resume_state (dict): Resume state.
|
287 |
+
"""
|
288 |
+
resume_optimizers = resume_state['optimizers']
|
289 |
+
resume_schedulers = resume_state['schedulers']
|
290 |
+
assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
|
291 |
+
assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
|
292 |
+
for i, o in enumerate(resume_optimizers):
|
293 |
+
self.optimizers[i].load_state_dict(o)
|
294 |
+
for i, s in enumerate(resume_schedulers):
|
295 |
+
self.schedulers[i].load_state_dict(s)
|
296 |
+
|
297 |
+
def reduce_loss_dict(self, loss_dict):
|
298 |
+
"""reduce loss dict.
|
299 |
+
|
300 |
+
In distributed training, it averages the losses among different GPUs .
|
301 |
+
|
302 |
+
Args:
|
303 |
+
loss_dict (OrderedDict): Loss dict.
|
304 |
+
"""
|
305 |
+
with torch.no_grad():
|
306 |
+
if self.opt['dist']:
|
307 |
+
keys = []
|
308 |
+
losses = []
|
309 |
+
for name, value in loss_dict.items():
|
310 |
+
keys.append(name)
|
311 |
+
losses.append(value)
|
312 |
+
losses = torch.stack(losses, 0)
|
313 |
+
torch.distributed.reduce(losses, dst=0)
|
314 |
+
if self.opt['rank'] == 0:
|
315 |
+
losses /= self.opt['world_size']
|
316 |
+
loss_dict = {key: loss for key, loss in zip(keys, losses)}
|
317 |
+
|
318 |
+
log_dict = OrderedDict()
|
319 |
+
for name, value in loss_dict.items():
|
320 |
+
log_dict[name] = value.mean().item()
|
321 |
+
|
322 |
+
return log_dict
|
blissful_tuner/codeformer/basicsr/models/codeformer_idx_model.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from collections import OrderedDict
|
3 |
+
from os import path as osp
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
from codeformer.basicsr.archs import build_network
|
7 |
+
from codeformer.basicsr.metrics import calculate_metric
|
8 |
+
from codeformer.basicsr.utils import get_root_logger, imwrite, tensor2img
|
9 |
+
from codeformer.basicsr.utils.registry import MODEL_REGISTRY
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from .sr_model import SRModel
|
12 |
+
|
13 |
+
|
14 |
+
@MODEL_REGISTRY.register()
|
15 |
+
class CodeFormerIdxModel(SRModel):
|
16 |
+
def feed_data(self, data):
|
17 |
+
self.gt = data['gt'].to(self.device)
|
18 |
+
self.input = data['in'].to(self.device)
|
19 |
+
self.b = self.gt.shape[0]
|
20 |
+
|
21 |
+
if 'latent_gt' in data:
|
22 |
+
self.idx_gt = data['latent_gt'].to(self.device)
|
23 |
+
self.idx_gt = self.idx_gt.view(self.b, -1)
|
24 |
+
else:
|
25 |
+
self.idx_gt = None
|
26 |
+
|
27 |
+
def init_training_settings(self):
|
28 |
+
logger = get_root_logger()
|
29 |
+
train_opt = self.opt['train']
|
30 |
+
|
31 |
+
self.ema_decay = train_opt.get('ema_decay', 0)
|
32 |
+
if self.ema_decay > 0:
|
33 |
+
logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
|
34 |
+
# define network net_g with Exponential Moving Average (EMA)
|
35 |
+
# net_g_ema is used only for testing on one GPU and saving
|
36 |
+
# There is no need to wrap with DistributedDataParallel
|
37 |
+
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
|
38 |
+
# load pretrained model
|
39 |
+
load_path = self.opt['path'].get('pretrain_network_g', None)
|
40 |
+
if load_path is not None:
|
41 |
+
self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
|
42 |
+
else:
|
43 |
+
self.model_ema(0) # copy net_g weight
|
44 |
+
self.net_g_ema.eval()
|
45 |
+
|
46 |
+
if self.opt['datasets']['train'].get('latent_gt_path', None) is not None:
|
47 |
+
self.generate_idx_gt = False
|
48 |
+
elif self.opt.get('network_vqgan', None) is not None:
|
49 |
+
self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
|
50 |
+
self.hq_vqgan_fix.eval()
|
51 |
+
self.generate_idx_gt = True
|
52 |
+
for param in self.hq_vqgan_fix.parameters():
|
53 |
+
param.requires_grad = False
|
54 |
+
else:
|
55 |
+
raise NotImplementedError(f'Shoule have network_vqgan config or pre-calculated latent code.')
|
56 |
+
|
57 |
+
logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}')
|
58 |
+
|
59 |
+
self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
|
60 |
+
self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
|
61 |
+
self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
|
62 |
+
self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
|
63 |
+
|
64 |
+
self.net_g.train()
|
65 |
+
|
66 |
+
# set up optimizers and schedulers
|
67 |
+
self.setup_optimizers()
|
68 |
+
self.setup_schedulers()
|
69 |
+
|
70 |
+
|
71 |
+
def setup_optimizers(self):
|
72 |
+
train_opt = self.opt['train']
|
73 |
+
# optimizer g
|
74 |
+
optim_params_g = []
|
75 |
+
for k, v in self.net_g.named_parameters():
|
76 |
+
if v.requires_grad:
|
77 |
+
optim_params_g.append(v)
|
78 |
+
else:
|
79 |
+
logger = get_root_logger()
|
80 |
+
logger.warning(f'Params {k} will not be optimized.')
|
81 |
+
optim_type = train_opt['optim_g'].pop('type')
|
82 |
+
self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
|
83 |
+
self.optimizers.append(self.optimizer_g)
|
84 |
+
|
85 |
+
|
86 |
+
def optimize_parameters(self, current_iter):
|
87 |
+
logger = get_root_logger()
|
88 |
+
# optimize net_g
|
89 |
+
self.optimizer_g.zero_grad()
|
90 |
+
|
91 |
+
if self.generate_idx_gt:
|
92 |
+
x = self.hq_vqgan_fix.encoder(self.gt)
|
93 |
+
_, _, quant_stats = self.hq_vqgan_fix.quantize(x)
|
94 |
+
min_encoding_indices = quant_stats['min_encoding_indices']
|
95 |
+
self.idx_gt = min_encoding_indices.view(self.b, -1)
|
96 |
+
|
97 |
+
if self.hq_feat_loss:
|
98 |
+
# quant_feats
|
99 |
+
quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256])
|
100 |
+
|
101 |
+
logits, lq_feat = self.net_g(self.input, w=0, code_only=True)
|
102 |
+
|
103 |
+
l_g_total = 0
|
104 |
+
loss_dict = OrderedDict()
|
105 |
+
# hq_feat_loss
|
106 |
+
if self.hq_feat_loss: # codebook loss
|
107 |
+
l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight
|
108 |
+
l_g_total += l_feat_encoder
|
109 |
+
loss_dict['l_feat_encoder'] = l_feat_encoder
|
110 |
+
|
111 |
+
# cross_entropy_loss
|
112 |
+
if self.cross_entropy_loss:
|
113 |
+
# b(hw)n -> bn(hw)
|
114 |
+
cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight
|
115 |
+
l_g_total += cross_entropy_loss
|
116 |
+
loss_dict['cross_entropy_loss'] = cross_entropy_loss
|
117 |
+
|
118 |
+
l_g_total.backward()
|
119 |
+
self.optimizer_g.step()
|
120 |
+
|
121 |
+
if self.ema_decay > 0:
|
122 |
+
self.model_ema(decay=self.ema_decay)
|
123 |
+
|
124 |
+
self.log_dict = self.reduce_loss_dict(loss_dict)
|
125 |
+
|
126 |
+
|
127 |
+
def test(self):
|
128 |
+
with torch.no_grad():
|
129 |
+
if hasattr(self, 'net_g_ema'):
|
130 |
+
self.net_g_ema.eval()
|
131 |
+
self.output, _, _ = self.net_g_ema(self.input, w=0)
|
132 |
+
else:
|
133 |
+
logger = get_root_logger()
|
134 |
+
logger.warning('Do not have self.net_g_ema, use self.net_g.')
|
135 |
+
self.net_g.eval()
|
136 |
+
self.output, _, _ = self.net_g(self.input, w=0)
|
137 |
+
self.net_g.train()
|
138 |
+
|
139 |
+
|
140 |
+
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
141 |
+
if self.opt['rank'] == 0:
|
142 |
+
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
143 |
+
|
144 |
+
|
145 |
+
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
146 |
+
dataset_name = dataloader.dataset.opt['name']
|
147 |
+
with_metrics = self.opt['val'].get('metrics') is not None
|
148 |
+
if with_metrics:
|
149 |
+
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
|
150 |
+
pbar = tqdm(total=len(dataloader), unit='image')
|
151 |
+
|
152 |
+
for idx, val_data in enumerate(dataloader):
|
153 |
+
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
|
154 |
+
self.feed_data(val_data)
|
155 |
+
self.test()
|
156 |
+
|
157 |
+
visuals = self.get_current_visuals()
|
158 |
+
sr_img = tensor2img([visuals['result']])
|
159 |
+
if 'gt' in visuals:
|
160 |
+
gt_img = tensor2img([visuals['gt']])
|
161 |
+
del self.gt
|
162 |
+
|
163 |
+
# tentative for out of GPU memory
|
164 |
+
del self.lq
|
165 |
+
del self.output
|
166 |
+
torch.cuda.empty_cache()
|
167 |
+
|
168 |
+
if save_img:
|
169 |
+
if self.opt['is_train']:
|
170 |
+
save_img_path = osp.join(self.opt['path']['visualization'], img_name,
|
171 |
+
f'{img_name}_{current_iter}.png')
|
172 |
+
else:
|
173 |
+
if self.opt['val']['suffix']:
|
174 |
+
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
|
175 |
+
f'{img_name}_{self.opt["val"]["suffix"]}.png')
|
176 |
+
else:
|
177 |
+
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
|
178 |
+
f'{img_name}_{self.opt["name"]}.png')
|
179 |
+
imwrite(sr_img, save_img_path)
|
180 |
+
|
181 |
+
if with_metrics:
|
182 |
+
# calculate metrics
|
183 |
+
for name, opt_ in self.opt['val']['metrics'].items():
|
184 |
+
metric_data = dict(img1=sr_img, img2=gt_img)
|
185 |
+
self.metric_results[name] += calculate_metric(metric_data, opt_)
|
186 |
+
pbar.update(1)
|
187 |
+
pbar.set_description(f'Test {img_name}')
|
188 |
+
pbar.close()
|
189 |
+
|
190 |
+
if with_metrics:
|
191 |
+
for metric in self.metric_results.keys():
|
192 |
+
self.metric_results[metric] /= (idx + 1)
|
193 |
+
|
194 |
+
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
|
195 |
+
|
196 |
+
|
197 |
+
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
|
198 |
+
log_str = f'Validation {dataset_name}\n'
|
199 |
+
for metric, value in self.metric_results.items():
|
200 |
+
log_str += f'\t # {metric}: {value:.4f}\n'
|
201 |
+
logger = get_root_logger()
|
202 |
+
logger.info(log_str)
|
203 |
+
if tb_logger:
|
204 |
+
for metric, value in self.metric_results.items():
|
205 |
+
tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
|
206 |
+
|
207 |
+
|
208 |
+
def get_current_visuals(self):
|
209 |
+
out_dict = OrderedDict()
|
210 |
+
out_dict['gt'] = self.gt.detach().cpu()
|
211 |
+
out_dict['result'] = self.output.detach().cpu()
|
212 |
+
return out_dict
|
213 |
+
|
214 |
+
|
215 |
+
def save(self, epoch, current_iter):
|
216 |
+
if self.ema_decay > 0:
|
217 |
+
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
|
218 |
+
else:
|
219 |
+
self.save_network(self.net_g, 'net_g', current_iter)
|
220 |
+
self.save_training_state(epoch, current_iter)
|
blissful_tuner/codeformer/basicsr/models/codeformer_joint_model.py
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from collections import OrderedDict
|
3 |
+
from os import path as osp
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
|
7 |
+
from codeformer.basicsr.archs import build_network
|
8 |
+
from codeformer.basicsr.losses import build_loss
|
9 |
+
from codeformer.basicsr.metrics import calculate_metric
|
10 |
+
from codeformer.basicsr.utils import get_root_logger, imwrite, tensor2img
|
11 |
+
from codeformer.basicsr.utils.registry import MODEL_REGISTRY
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from .sr_model import SRModel
|
14 |
+
|
15 |
+
|
16 |
+
@MODEL_REGISTRY.register()
|
17 |
+
class CodeFormerJointModel(SRModel):
|
18 |
+
def feed_data(self, data):
|
19 |
+
self.gt = data['gt'].to(self.device)
|
20 |
+
self.input = data['in'].to(self.device)
|
21 |
+
self.input_large_de = data['in_large_de'].to(self.device)
|
22 |
+
self.b = self.gt.shape[0]
|
23 |
+
|
24 |
+
if 'latent_gt' in data:
|
25 |
+
self.idx_gt = data['latent_gt'].to(self.device)
|
26 |
+
self.idx_gt = self.idx_gt.view(self.b, -1)
|
27 |
+
else:
|
28 |
+
self.idx_gt = None
|
29 |
+
|
30 |
+
def init_training_settings(self):
|
31 |
+
logger = get_root_logger()
|
32 |
+
train_opt = self.opt['train']
|
33 |
+
|
34 |
+
self.ema_decay = train_opt.get('ema_decay', 0)
|
35 |
+
if self.ema_decay > 0:
|
36 |
+
logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
|
37 |
+
# define network net_g with Exponential Moving Average (EMA)
|
38 |
+
# net_g_ema is used only for testing on one GPU and saving
|
39 |
+
# There is no need to wrap with DistributedDataParallel
|
40 |
+
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
|
41 |
+
# load pretrained model
|
42 |
+
load_path = self.opt['path'].get('pretrain_network_g', None)
|
43 |
+
if load_path is not None:
|
44 |
+
self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
|
45 |
+
else:
|
46 |
+
self.model_ema(0) # copy net_g weight
|
47 |
+
self.net_g_ema.eval()
|
48 |
+
|
49 |
+
if self.opt['datasets']['train'].get('latent_gt_path', None) is not None:
|
50 |
+
self.generate_idx_gt = False
|
51 |
+
elif self.opt.get('network_vqgan', None) is not None:
|
52 |
+
self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
|
53 |
+
self.hq_vqgan_fix.eval()
|
54 |
+
self.generate_idx_gt = True
|
55 |
+
for param in self.hq_vqgan_fix.parameters():
|
56 |
+
param.requires_grad = False
|
57 |
+
else:
|
58 |
+
raise NotImplementedError(f'Shoule have network_vqgan config or pre-calculated latent code.')
|
59 |
+
|
60 |
+
logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}')
|
61 |
+
|
62 |
+
self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
|
63 |
+
self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
|
64 |
+
self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
|
65 |
+
self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
|
66 |
+
self.scale_adaptive_gan_weight = train_opt.get('scale_adaptive_gan_weight', 0.8)
|
67 |
+
|
68 |
+
# define network net_d
|
69 |
+
self.net_d = build_network(self.opt['network_d'])
|
70 |
+
self.net_d = self.model_to_device(self.net_d)
|
71 |
+
self.print_network(self.net_d)
|
72 |
+
|
73 |
+
# load pretrained models
|
74 |
+
load_path = self.opt['path'].get('pretrain_network_d', None)
|
75 |
+
if load_path is not None:
|
76 |
+
self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
|
77 |
+
|
78 |
+
self.net_g.train()
|
79 |
+
self.net_d.train()
|
80 |
+
|
81 |
+
# define losses
|
82 |
+
if train_opt.get('pixel_opt'):
|
83 |
+
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
|
84 |
+
else:
|
85 |
+
self.cri_pix = None
|
86 |
+
|
87 |
+
if train_opt.get('perceptual_opt'):
|
88 |
+
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
|
89 |
+
else:
|
90 |
+
self.cri_perceptual = None
|
91 |
+
|
92 |
+
if train_opt.get('gan_opt'):
|
93 |
+
self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
|
94 |
+
|
95 |
+
|
96 |
+
self.fix_generator = train_opt.get('fix_generator', True)
|
97 |
+
logger.info(f'fix_generator: {self.fix_generator}')
|
98 |
+
|
99 |
+
self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
|
100 |
+
self.net_d_iters = train_opt.get('net_d_iters', 1)
|
101 |
+
self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
|
102 |
+
|
103 |
+
# set up optimizers and schedulers
|
104 |
+
self.setup_optimizers()
|
105 |
+
self.setup_schedulers()
|
106 |
+
|
107 |
+
def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
|
108 |
+
recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
|
109 |
+
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
110 |
+
|
111 |
+
d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
|
112 |
+
d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
|
113 |
+
return d_weight
|
114 |
+
|
115 |
+
def setup_optimizers(self):
|
116 |
+
train_opt = self.opt['train']
|
117 |
+
# optimizer g
|
118 |
+
optim_params_g = []
|
119 |
+
for k, v in self.net_g.named_parameters():
|
120 |
+
if v.requires_grad:
|
121 |
+
optim_params_g.append(v)
|
122 |
+
else:
|
123 |
+
logger = get_root_logger()
|
124 |
+
logger.warning(f'Params {k} will not be optimized.')
|
125 |
+
optim_type = train_opt['optim_g'].pop('type')
|
126 |
+
self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
|
127 |
+
self.optimizers.append(self.optimizer_g)
|
128 |
+
# optimizer d
|
129 |
+
optim_type = train_opt['optim_d'].pop('type')
|
130 |
+
self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
|
131 |
+
self.optimizers.append(self.optimizer_d)
|
132 |
+
|
133 |
+
def gray_resize_for_identity(self, out, size=128):
|
134 |
+
out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
|
135 |
+
out_gray = out_gray.unsqueeze(1)
|
136 |
+
out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
|
137 |
+
return out_gray
|
138 |
+
|
139 |
+
def optimize_parameters(self, current_iter):
|
140 |
+
logger = get_root_logger()
|
141 |
+
# optimize net_g
|
142 |
+
for p in self.net_d.parameters():
|
143 |
+
p.requires_grad = False
|
144 |
+
|
145 |
+
self.optimizer_g.zero_grad()
|
146 |
+
|
147 |
+
if self.generate_idx_gt:
|
148 |
+
x = self.hq_vqgan_fix.encoder(self.gt)
|
149 |
+
output, _, quant_stats = self.hq_vqgan_fix.quantize(x)
|
150 |
+
min_encoding_indices = quant_stats['min_encoding_indices']
|
151 |
+
self.idx_gt = min_encoding_indices.view(self.b, -1)
|
152 |
+
|
153 |
+
if current_iter <= 40000: # small degradation
|
154 |
+
small_per_n = 1
|
155 |
+
w = 1
|
156 |
+
elif current_iter <= 80000: # small degradation
|
157 |
+
small_per_n = 1
|
158 |
+
w = 1.3
|
159 |
+
elif current_iter <= 120000: # large degradation
|
160 |
+
small_per_n = 120000
|
161 |
+
w = 0
|
162 |
+
else: # mixed degradation
|
163 |
+
small_per_n = 15
|
164 |
+
w = 1.3
|
165 |
+
|
166 |
+
if current_iter % small_per_n == 0:
|
167 |
+
self.output, logits, lq_feat = self.net_g(self.input, w=w, detach_16=True)
|
168 |
+
large_de = False
|
169 |
+
else:
|
170 |
+
logits, lq_feat = self.net_g(self.input_large_de, code_only=True)
|
171 |
+
large_de = True
|
172 |
+
|
173 |
+
if self.hq_feat_loss:
|
174 |
+
# quant_feats
|
175 |
+
quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256])
|
176 |
+
|
177 |
+
l_g_total = 0
|
178 |
+
loss_dict = OrderedDict()
|
179 |
+
if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
|
180 |
+
# hq_feat_loss
|
181 |
+
if not 'transformer' in self.opt['network_g']['fix_modules']:
|
182 |
+
if self.hq_feat_loss: # codebook loss
|
183 |
+
l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight
|
184 |
+
l_g_total += l_feat_encoder
|
185 |
+
loss_dict['l_feat_encoder'] = l_feat_encoder
|
186 |
+
|
187 |
+
# cross_entropy_loss
|
188 |
+
if self.cross_entropy_loss:
|
189 |
+
# b(hw)n -> bn(hw)
|
190 |
+
cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight
|
191 |
+
l_g_total += cross_entropy_loss
|
192 |
+
loss_dict['cross_entropy_loss'] = cross_entropy_loss
|
193 |
+
|
194 |
+
# pixel loss
|
195 |
+
if not large_de: # when large degradation don't need image-level loss
|
196 |
+
if self.cri_pix:
|
197 |
+
l_g_pix = self.cri_pix(self.output, self.gt)
|
198 |
+
l_g_total += l_g_pix
|
199 |
+
loss_dict['l_g_pix'] = l_g_pix
|
200 |
+
|
201 |
+
# perceptual loss
|
202 |
+
if self.cri_perceptual:
|
203 |
+
l_g_percep = self.cri_perceptual(self.output, self.gt)
|
204 |
+
l_g_total += l_g_percep
|
205 |
+
loss_dict['l_g_percep'] = l_g_percep
|
206 |
+
|
207 |
+
# gan loss
|
208 |
+
if current_iter > self.net_d_start_iter:
|
209 |
+
fake_g_pred = self.net_d(self.output)
|
210 |
+
l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
|
211 |
+
recon_loss = l_g_pix + l_g_percep
|
212 |
+
if not self.fix_generator:
|
213 |
+
last_layer = self.net_g.module.generator.blocks[-1].weight
|
214 |
+
d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
|
215 |
+
else:
|
216 |
+
largest_fuse_size = self.opt['network_g']['connect_list'][-1]
|
217 |
+
last_layer = self.net_g.module.fuse_convs_dict[largest_fuse_size].shift[-1].weight
|
218 |
+
d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
|
219 |
+
|
220 |
+
d_weight *= self.scale_adaptive_gan_weight # 0.8
|
221 |
+
loss_dict['d_weight'] = d_weight
|
222 |
+
l_g_total += d_weight * l_g_gan
|
223 |
+
loss_dict['l_g_gan'] = d_weight * l_g_gan
|
224 |
+
|
225 |
+
l_g_total.backward()
|
226 |
+
self.optimizer_g.step()
|
227 |
+
|
228 |
+
if self.ema_decay > 0:
|
229 |
+
self.model_ema(decay=self.ema_decay)
|
230 |
+
|
231 |
+
# optimize net_d
|
232 |
+
if not large_de:
|
233 |
+
if current_iter > self.net_d_start_iter:
|
234 |
+
for p in self.net_d.parameters():
|
235 |
+
p.requires_grad = True
|
236 |
+
|
237 |
+
self.optimizer_d.zero_grad()
|
238 |
+
# real
|
239 |
+
real_d_pred = self.net_d(self.gt)
|
240 |
+
l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
|
241 |
+
loss_dict['l_d_real'] = l_d_real
|
242 |
+
loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
|
243 |
+
l_d_real.backward()
|
244 |
+
# fake
|
245 |
+
fake_d_pred = self.net_d(self.output.detach())
|
246 |
+
l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
|
247 |
+
loss_dict['l_d_fake'] = l_d_fake
|
248 |
+
loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
|
249 |
+
l_d_fake.backward()
|
250 |
+
|
251 |
+
self.optimizer_d.step()
|
252 |
+
|
253 |
+
self.log_dict = self.reduce_loss_dict(loss_dict)
|
254 |
+
|
255 |
+
|
256 |
+
def test(self):
|
257 |
+
with torch.no_grad():
|
258 |
+
if hasattr(self, 'net_g_ema'):
|
259 |
+
self.net_g_ema.eval()
|
260 |
+
self.output, _, _ = self.net_g_ema(self.input, w=1)
|
261 |
+
else:
|
262 |
+
logger = get_root_logger()
|
263 |
+
logger.warning('Do not have self.net_g_ema, use self.net_g.')
|
264 |
+
self.net_g.eval()
|
265 |
+
self.output, _, _ = self.net_g(self.input, w=1)
|
266 |
+
self.net_g.train()
|
267 |
+
|
268 |
+
|
269 |
+
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
270 |
+
if self.opt['rank'] == 0:
|
271 |
+
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
272 |
+
|
273 |
+
|
274 |
+
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
275 |
+
dataset_name = dataloader.dataset.opt['name']
|
276 |
+
with_metrics = self.opt['val'].get('metrics') is not None
|
277 |
+
if with_metrics:
|
278 |
+
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
|
279 |
+
pbar = tqdm(total=len(dataloader), unit='image')
|
280 |
+
|
281 |
+
for idx, val_data in enumerate(dataloader):
|
282 |
+
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
|
283 |
+
self.feed_data(val_data)
|
284 |
+
self.test()
|
285 |
+
|
286 |
+
visuals = self.get_current_visuals()
|
287 |
+
sr_img = tensor2img([visuals['result']])
|
288 |
+
if 'gt' in visuals:
|
289 |
+
gt_img = tensor2img([visuals['gt']])
|
290 |
+
del self.gt
|
291 |
+
|
292 |
+
# tentative for out of GPU memory
|
293 |
+
del self.lq
|
294 |
+
del self.output
|
295 |
+
torch.cuda.empty_cache()
|
296 |
+
|
297 |
+
if save_img:
|
298 |
+
if self.opt['is_train']:
|
299 |
+
save_img_path = osp.join(self.opt['path']['visualization'], img_name,
|
300 |
+
f'{img_name}_{current_iter}.png')
|
301 |
+
else:
|
302 |
+
if self.opt['val']['suffix']:
|
303 |
+
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
|
304 |
+
f'{img_name}_{self.opt["val"]["suffix"]}.png')
|
305 |
+
else:
|
306 |
+
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
|
307 |
+
f'{img_name}_{self.opt["name"]}.png')
|
308 |
+
imwrite(sr_img, save_img_path)
|
309 |
+
|
310 |
+
if with_metrics:
|
311 |
+
# calculate metrics
|
312 |
+
for name, opt_ in self.opt['val']['metrics'].items():
|
313 |
+
metric_data = dict(img1=sr_img, img2=gt_img)
|
314 |
+
self.metric_results[name] += calculate_metric(metric_data, opt_)
|
315 |
+
pbar.update(1)
|
316 |
+
pbar.set_description(f'Test {img_name}')
|
317 |
+
pbar.close()
|
318 |
+
|
319 |
+
if with_metrics:
|
320 |
+
for metric in self.metric_results.keys():
|
321 |
+
self.metric_results[metric] /= (idx + 1)
|
322 |
+
|
323 |
+
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
|
324 |
+
|
325 |
+
|
326 |
+
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
|
327 |
+
log_str = f'Validation {dataset_name}\n'
|
328 |
+
for metric, value in self.metric_results.items():
|
329 |
+
log_str += f'\t # {metric}: {value:.4f}\n'
|
330 |
+
logger = get_root_logger()
|
331 |
+
logger.info(log_str)
|
332 |
+
if tb_logger:
|
333 |
+
for metric, value in self.metric_results.items():
|
334 |
+
tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
|
335 |
+
|
336 |
+
|
337 |
+
def get_current_visuals(self):
|
338 |
+
out_dict = OrderedDict()
|
339 |
+
out_dict['gt'] = self.gt.detach().cpu()
|
340 |
+
out_dict['result'] = self.output.detach().cpu()
|
341 |
+
return out_dict
|
342 |
+
|
343 |
+
|
344 |
+
def save(self, epoch, current_iter):
|
345 |
+
if self.ema_decay > 0:
|
346 |
+
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
|
347 |
+
else:
|
348 |
+
self.save_network(self.net_g, 'net_g', current_iter)
|
349 |
+
self.save_network(self.net_d, 'net_d', current_iter)
|
350 |
+
self.save_training_state(epoch, current_iter)
|
blissful_tuner/codeformer/basicsr/models/codeformer_model.py
ADDED
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from collections import OrderedDict
|
3 |
+
from os import path as osp
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
from codeformer.basicsr.archs import build_network
|
7 |
+
from codeformer.basicsr.losses import build_loss
|
8 |
+
from codeformer.basicsr.metrics import calculate_metric
|
9 |
+
from codeformer.basicsr.utils import get_root_logger, imwrite, tensor2img
|
10 |
+
from codeformer.basicsr.utils.registry import MODEL_REGISTRY
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from .sr_model import SRModel
|
13 |
+
|
14 |
+
|
15 |
+
@MODEL_REGISTRY.register()
|
16 |
+
class CodeFormerModel(SRModel):
|
17 |
+
def feed_data(self, data):
|
18 |
+
self.gt = data['gt'].to(self.device)
|
19 |
+
self.input = data['in'].to(self.device)
|
20 |
+
self.b = self.gt.shape[0]
|
21 |
+
|
22 |
+
if 'latent_gt' in data:
|
23 |
+
self.idx_gt = data['latent_gt'].to(self.device)
|
24 |
+
self.idx_gt = self.idx_gt.view(self.b, -1)
|
25 |
+
else:
|
26 |
+
self.idx_gt = None
|
27 |
+
|
28 |
+
def init_training_settings(self):
|
29 |
+
logger = get_root_logger()
|
30 |
+
train_opt = self.opt['train']
|
31 |
+
|
32 |
+
self.ema_decay = train_opt.get('ema_decay', 0)
|
33 |
+
if self.ema_decay > 0:
|
34 |
+
logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
|
35 |
+
# define network net_g with Exponential Moving Average (EMA)
|
36 |
+
# net_g_ema is used only for testing on one GPU and saving
|
37 |
+
# There is no need to wrap with DistributedDataParallel
|
38 |
+
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
|
39 |
+
# load pretrained model
|
40 |
+
load_path = self.opt['path'].get('pretrain_network_g', None)
|
41 |
+
if load_path is not None:
|
42 |
+
self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
|
43 |
+
else:
|
44 |
+
self.model_ema(0) # copy net_g weight
|
45 |
+
self.net_g_ema.eval()
|
46 |
+
|
47 |
+
if self.opt.get('network_vqgan', None) is not None and self.opt['datasets'].get('latent_gt_path') is None:
|
48 |
+
self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
|
49 |
+
self.hq_vqgan_fix.eval()
|
50 |
+
self.generate_idx_gt = True
|
51 |
+
for param in self.hq_vqgan_fix.parameters():
|
52 |
+
param.requires_grad = False
|
53 |
+
else:
|
54 |
+
self.generate_idx_gt = False
|
55 |
+
|
56 |
+
self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
|
57 |
+
self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
|
58 |
+
self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
|
59 |
+
self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
|
60 |
+
self.fidelity_weight = train_opt.get('fidelity_weight', 1.0)
|
61 |
+
self.scale_adaptive_gan_weight = train_opt.get('scale_adaptive_gan_weight', 0.8)
|
62 |
+
|
63 |
+
|
64 |
+
self.net_g.train()
|
65 |
+
# define network net_d
|
66 |
+
if self.fidelity_weight > 0:
|
67 |
+
self.net_d = build_network(self.opt['network_d'])
|
68 |
+
self.net_d = self.model_to_device(self.net_d)
|
69 |
+
self.print_network(self.net_d)
|
70 |
+
|
71 |
+
# load pretrained models
|
72 |
+
load_path = self.opt['path'].get('pretrain_network_d', None)
|
73 |
+
if load_path is not None:
|
74 |
+
self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
|
75 |
+
|
76 |
+
self.net_d.train()
|
77 |
+
|
78 |
+
# define losses
|
79 |
+
if train_opt.get('pixel_opt'):
|
80 |
+
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
|
81 |
+
else:
|
82 |
+
self.cri_pix = None
|
83 |
+
|
84 |
+
if train_opt.get('perceptual_opt'):
|
85 |
+
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
|
86 |
+
else:
|
87 |
+
self.cri_perceptual = None
|
88 |
+
|
89 |
+
if train_opt.get('gan_opt'):
|
90 |
+
self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
|
91 |
+
|
92 |
+
|
93 |
+
self.fix_generator = train_opt.get('fix_generator', True)
|
94 |
+
logger.info(f'fix_generator: {self.fix_generator}')
|
95 |
+
|
96 |
+
self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
|
97 |
+
self.net_d_iters = train_opt.get('net_d_iters', 1)
|
98 |
+
self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
|
99 |
+
|
100 |
+
# set up optimizers and schedulers
|
101 |
+
self.setup_optimizers()
|
102 |
+
self.setup_schedulers()
|
103 |
+
|
104 |
+
def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
|
105 |
+
recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
|
106 |
+
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
107 |
+
|
108 |
+
d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
|
109 |
+
d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
|
110 |
+
return d_weight
|
111 |
+
|
112 |
+
def setup_optimizers(self):
|
113 |
+
train_opt = self.opt['train']
|
114 |
+
# optimizer g
|
115 |
+
optim_params_g = []
|
116 |
+
for k, v in self.net_g.named_parameters():
|
117 |
+
if v.requires_grad:
|
118 |
+
optim_params_g.append(v)
|
119 |
+
else:
|
120 |
+
logger = get_root_logger()
|
121 |
+
logger.warning(f'Params {k} will not be optimized.')
|
122 |
+
optim_type = train_opt['optim_g'].pop('type')
|
123 |
+
self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
|
124 |
+
self.optimizers.append(self.optimizer_g)
|
125 |
+
# optimizer d
|
126 |
+
if self.fidelity_weight > 0:
|
127 |
+
optim_type = train_opt['optim_d'].pop('type')
|
128 |
+
self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
|
129 |
+
self.optimizers.append(self.optimizer_d)
|
130 |
+
|
131 |
+
def gray_resize_for_identity(self, out, size=128):
|
132 |
+
out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
|
133 |
+
out_gray = out_gray.unsqueeze(1)
|
134 |
+
out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
|
135 |
+
return out_gray
|
136 |
+
|
137 |
+
def optimize_parameters(self, current_iter):
|
138 |
+
logger = get_root_logger()
|
139 |
+
# optimize net_g
|
140 |
+
for p in self.net_d.parameters():
|
141 |
+
p.requires_grad = False
|
142 |
+
|
143 |
+
self.optimizer_g.zero_grad()
|
144 |
+
|
145 |
+
if self.generate_idx_gt:
|
146 |
+
x = self.hq_vqgan_fix.encoder(self.gt)
|
147 |
+
output, _, quant_stats = self.hq_vqgan_fix.quantize(x)
|
148 |
+
min_encoding_indices = quant_stats['min_encoding_indices']
|
149 |
+
self.idx_gt = min_encoding_indices.view(self.b, -1)
|
150 |
+
|
151 |
+
if self.fidelity_weight > 0:
|
152 |
+
self.output, logits, lq_feat = self.net_g(self.input, w=self.fidelity_weight, detach_16=True)
|
153 |
+
else:
|
154 |
+
logits, lq_feat = self.net_g(self.input, w=0, code_only=True)
|
155 |
+
|
156 |
+
if self.hq_feat_loss:
|
157 |
+
# quant_feats
|
158 |
+
quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256])
|
159 |
+
|
160 |
+
l_g_total = 0
|
161 |
+
loss_dict = OrderedDict()
|
162 |
+
if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
|
163 |
+
# hq_feat_loss
|
164 |
+
if self.hq_feat_loss: # codebook loss
|
165 |
+
l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight
|
166 |
+
l_g_total += l_feat_encoder
|
167 |
+
loss_dict['l_feat_encoder'] = l_feat_encoder
|
168 |
+
|
169 |
+
# cross_entropy_loss
|
170 |
+
if self.cross_entropy_loss:
|
171 |
+
# b(hw)n -> bn(hw)
|
172 |
+
cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight
|
173 |
+
l_g_total += cross_entropy_loss
|
174 |
+
loss_dict['cross_entropy_loss'] = cross_entropy_loss
|
175 |
+
|
176 |
+
if self.fidelity_weight > 0: # when fidelity_weight == 0 don't need image-level loss
|
177 |
+
# pixel loss
|
178 |
+
if self.cri_pix:
|
179 |
+
l_g_pix = self.cri_pix(self.output, self.gt)
|
180 |
+
l_g_total += l_g_pix
|
181 |
+
loss_dict['l_g_pix'] = l_g_pix
|
182 |
+
|
183 |
+
# perceptual loss
|
184 |
+
if self.cri_perceptual:
|
185 |
+
l_g_percep = self.cri_perceptual(self.output, self.gt)
|
186 |
+
l_g_total += l_g_percep
|
187 |
+
loss_dict['l_g_percep'] = l_g_percep
|
188 |
+
|
189 |
+
# gan loss
|
190 |
+
if current_iter > self.net_d_start_iter:
|
191 |
+
fake_g_pred = self.net_d(self.output)
|
192 |
+
l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
|
193 |
+
recon_loss = l_g_pix + l_g_percep
|
194 |
+
if not self.fix_generator:
|
195 |
+
last_layer = self.net_g.module.generator.blocks[-1].weight
|
196 |
+
d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
|
197 |
+
else:
|
198 |
+
largest_fuse_size = self.opt['network_g']['connect_list'][-1]
|
199 |
+
last_layer = self.net_g.module.fuse_convs_dict[largest_fuse_size].shift[-1].weight
|
200 |
+
d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
|
201 |
+
|
202 |
+
d_weight *= self.scale_adaptive_gan_weight # 0.8
|
203 |
+
loss_dict['d_weight'] = d_weight
|
204 |
+
l_g_total += d_weight * l_g_gan
|
205 |
+
loss_dict['l_g_gan'] = d_weight * l_g_gan
|
206 |
+
|
207 |
+
l_g_total.backward()
|
208 |
+
self.optimizer_g.step()
|
209 |
+
|
210 |
+
if self.ema_decay > 0:
|
211 |
+
self.model_ema(decay=self.ema_decay)
|
212 |
+
|
213 |
+
# optimize net_d
|
214 |
+
if current_iter > self.net_d_start_iter and self.fidelity_weight > 0:
|
215 |
+
for p in self.net_d.parameters():
|
216 |
+
p.requires_grad = True
|
217 |
+
|
218 |
+
self.optimizer_d.zero_grad()
|
219 |
+
# real
|
220 |
+
real_d_pred = self.net_d(self.gt)
|
221 |
+
l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
|
222 |
+
loss_dict['l_d_real'] = l_d_real
|
223 |
+
loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
|
224 |
+
l_d_real.backward()
|
225 |
+
# fake
|
226 |
+
fake_d_pred = self.net_d(self.output.detach())
|
227 |
+
l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
|
228 |
+
loss_dict['l_d_fake'] = l_d_fake
|
229 |
+
loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
|
230 |
+
l_d_fake.backward()
|
231 |
+
|
232 |
+
self.optimizer_d.step()
|
233 |
+
|
234 |
+
self.log_dict = self.reduce_loss_dict(loss_dict)
|
235 |
+
|
236 |
+
|
237 |
+
def test(self):
|
238 |
+
with torch.no_grad():
|
239 |
+
if hasattr(self, 'net_g_ema'):
|
240 |
+
self.net_g_ema.eval()
|
241 |
+
self.output, _, _ = self.net_g_ema(self.input, w=self.fidelity_weight)
|
242 |
+
else:
|
243 |
+
logger = get_root_logger()
|
244 |
+
logger.warning('Do not have self.net_g_ema, use self.net_g.')
|
245 |
+
self.net_g.eval()
|
246 |
+
self.output, _, _ = self.net_g(self.input, w=self.fidelity_weight)
|
247 |
+
self.net_g.train()
|
248 |
+
|
249 |
+
|
250 |
+
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
251 |
+
if self.opt['rank'] == 0:
|
252 |
+
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
253 |
+
|
254 |
+
|
255 |
+
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
256 |
+
dataset_name = dataloader.dataset.opt['name']
|
257 |
+
with_metrics = self.opt['val'].get('metrics') is not None
|
258 |
+
if with_metrics:
|
259 |
+
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
|
260 |
+
pbar = tqdm(total=len(dataloader), unit='image')
|
261 |
+
|
262 |
+
for idx, val_data in enumerate(dataloader):
|
263 |
+
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
|
264 |
+
self.feed_data(val_data)
|
265 |
+
self.test()
|
266 |
+
|
267 |
+
visuals = self.get_current_visuals()
|
268 |
+
sr_img = tensor2img([visuals['result']])
|
269 |
+
if 'gt' in visuals:
|
270 |
+
gt_img = tensor2img([visuals['gt']])
|
271 |
+
del self.gt
|
272 |
+
|
273 |
+
# tentative for out of GPU memory
|
274 |
+
del self.lq
|
275 |
+
del self.output
|
276 |
+
torch.cuda.empty_cache()
|
277 |
+
|
278 |
+
if save_img:
|
279 |
+
if self.opt['is_train']:
|
280 |
+
save_img_path = osp.join(self.opt['path']['visualization'], img_name,
|
281 |
+
f'{img_name}_{current_iter}.png')
|
282 |
+
else:
|
283 |
+
if self.opt['val']['suffix']:
|
284 |
+
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
|
285 |
+
f'{img_name}_{self.opt["val"]["suffix"]}.png')
|
286 |
+
else:
|
287 |
+
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
|
288 |
+
f'{img_name}_{self.opt["name"]}.png')
|
289 |
+
imwrite(sr_img, save_img_path)
|
290 |
+
|
291 |
+
if with_metrics:
|
292 |
+
# calculate metrics
|
293 |
+
for name, opt_ in self.opt['val']['metrics'].items():
|
294 |
+
metric_data = dict(img1=sr_img, img2=gt_img)
|
295 |
+
self.metric_results[name] += calculate_metric(metric_data, opt_)
|
296 |
+
pbar.update(1)
|
297 |
+
pbar.set_description(f'Test {img_name}')
|
298 |
+
pbar.close()
|
299 |
+
|
300 |
+
if with_metrics:
|
301 |
+
for metric in self.metric_results.keys():
|
302 |
+
self.metric_results[metric] /= (idx + 1)
|
303 |
+
|
304 |
+
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
|
305 |
+
|
306 |
+
|
307 |
+
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
|
308 |
+
log_str = f'Validation {dataset_name}\n'
|
309 |
+
for metric, value in self.metric_results.items():
|
310 |
+
log_str += f'\t # {metric}: {value:.4f}\n'
|
311 |
+
logger = get_root_logger()
|
312 |
+
logger.info(log_str)
|
313 |
+
if tb_logger:
|
314 |
+
for metric, value in self.metric_results.items():
|
315 |
+
tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
|
316 |
+
|
317 |
+
|
318 |
+
def get_current_visuals(self):
|
319 |
+
out_dict = OrderedDict()
|
320 |
+
out_dict['gt'] = self.gt.detach().cpu()
|
321 |
+
out_dict['result'] = self.output.detach().cpu()
|
322 |
+
return out_dict
|
323 |
+
|
324 |
+
|
325 |
+
def save(self, epoch, current_iter):
|
326 |
+
if self.ema_decay > 0:
|
327 |
+
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
|
328 |
+
else:
|
329 |
+
self.save_network(self.net_g, 'net_g', current_iter)
|
330 |
+
if self.fidelity_weight > 0:
|
331 |
+
self.save_network(self.net_d, 'net_d', current_iter)
|
332 |
+
self.save_training_state(epoch, current_iter)
|
blissful_tuner/codeformer/basicsr/models/lr_scheduler.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from collections import Counter
|
3 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
4 |
+
|
5 |
+
|
6 |
+
class MultiStepRestartLR(_LRScheduler):
|
7 |
+
""" MultiStep with restarts learning rate scheme.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
optimizer (torch.nn.optimizer): Torch optimizer.
|
11 |
+
milestones (list): Iterations that will decrease learning rate.
|
12 |
+
gamma (float): Decrease ratio. Default: 0.1.
|
13 |
+
restarts (list): Restart iterations. Default: [0].
|
14 |
+
restart_weights (list): Restart weights at each restart iteration.
|
15 |
+
Default: [1].
|
16 |
+
last_epoch (int): Used in _LRScheduler. Default: -1.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
|
20 |
+
self.milestones = Counter(milestones)
|
21 |
+
self.gamma = gamma
|
22 |
+
self.restarts = restarts
|
23 |
+
self.restart_weights = restart_weights
|
24 |
+
assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
|
25 |
+
super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
|
26 |
+
|
27 |
+
def get_lr(self):
|
28 |
+
if self.last_epoch in self.restarts:
|
29 |
+
weight = self.restart_weights[self.restarts.index(self.last_epoch)]
|
30 |
+
return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
|
31 |
+
if self.last_epoch not in self.milestones:
|
32 |
+
return [group['lr'] for group in self.optimizer.param_groups]
|
33 |
+
return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups]
|
34 |
+
|
35 |
+
|
36 |
+
def get_position_from_periods(iteration, cumulative_period):
|
37 |
+
"""Get the position from a period list.
|
38 |
+
|
39 |
+
It will return the index of the right-closest number in the period list.
|
40 |
+
For example, the cumulative_period = [100, 200, 300, 400],
|
41 |
+
if iteration == 50, return 0;
|
42 |
+
if iteration == 210, return 2;
|
43 |
+
if iteration == 300, return 2.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
iteration (int): Current iteration.
|
47 |
+
cumulative_period (list[int]): Cumulative period list.
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
int: The position of the right-closest number in the period list.
|
51 |
+
"""
|
52 |
+
for i, period in enumerate(cumulative_period):
|
53 |
+
if iteration <= period:
|
54 |
+
return i
|
55 |
+
|
56 |
+
|
57 |
+
class CosineAnnealingRestartLR(_LRScheduler):
|
58 |
+
""" Cosine annealing with restarts learning rate scheme.
|
59 |
+
|
60 |
+
An example of config:
|
61 |
+
periods = [10, 10, 10, 10]
|
62 |
+
restart_weights = [1, 0.5, 0.5, 0.5]
|
63 |
+
eta_min=1e-7
|
64 |
+
|
65 |
+
It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
|
66 |
+
scheduler will restart with the weights in restart_weights.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
optimizer (torch.nn.optimizer): Torch optimizer.
|
70 |
+
periods (list): Period for each cosine anneling cycle.
|
71 |
+
restart_weights (list): Restart weights at each restart iteration.
|
72 |
+
Default: [1].
|
73 |
+
eta_min (float): The mimimum lr. Default: 0.
|
74 |
+
last_epoch (int): Used in _LRScheduler. Default: -1.
|
75 |
+
"""
|
76 |
+
|
77 |
+
def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
|
78 |
+
self.periods = periods
|
79 |
+
self.restart_weights = restart_weights
|
80 |
+
self.eta_min = eta_min
|
81 |
+
assert (len(self.periods) == len(
|
82 |
+
self.restart_weights)), 'periods and restart_weights should have the same length.'
|
83 |
+
self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
|
84 |
+
super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
|
85 |
+
|
86 |
+
def get_lr(self):
|
87 |
+
idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
|
88 |
+
current_weight = self.restart_weights[idx]
|
89 |
+
nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
|
90 |
+
current_period = self.periods[idx]
|
91 |
+
|
92 |
+
return [
|
93 |
+
self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
|
94 |
+
(1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
|
95 |
+
for base_lr in self.base_lrs
|
96 |
+
]
|
blissful_tuner/codeformer/basicsr/models/sr_model.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from collections import OrderedDict
|
3 |
+
from os import path as osp
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
from codeformer.basicsr.archs import build_network
|
7 |
+
from codeformer.basicsr.losses import build_loss
|
8 |
+
from codeformer.basicsr.metrics import calculate_metric
|
9 |
+
from codeformer.basicsr.utils import get_root_logger, imwrite, tensor2img
|
10 |
+
from codeformer.basicsr.utils.registry import MODEL_REGISTRY
|
11 |
+
from .base_model import BaseModel
|
12 |
+
|
13 |
+
@MODEL_REGISTRY.register()
|
14 |
+
class SRModel(BaseModel):
|
15 |
+
"""Base SR model for single image super-resolution."""
|
16 |
+
|
17 |
+
def __init__(self, opt):
|
18 |
+
super(SRModel, self).__init__(opt)
|
19 |
+
|
20 |
+
# define network
|
21 |
+
self.net_g = build_network(opt['network_g'])
|
22 |
+
self.net_g = self.model_to_device(self.net_g)
|
23 |
+
self.print_network(self.net_g)
|
24 |
+
|
25 |
+
# load pretrained models
|
26 |
+
load_path = self.opt['path'].get('pretrain_network_g', None)
|
27 |
+
if load_path is not None:
|
28 |
+
param_key = self.opt['path'].get('param_key_g', 'params')
|
29 |
+
self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
|
30 |
+
|
31 |
+
if self.is_train:
|
32 |
+
self.init_training_settings()
|
33 |
+
|
34 |
+
def init_training_settings(self):
|
35 |
+
self.net_g.train()
|
36 |
+
train_opt = self.opt['train']
|
37 |
+
|
38 |
+
self.ema_decay = train_opt.get('ema_decay', 0)
|
39 |
+
if self.ema_decay > 0:
|
40 |
+
logger = get_root_logger()
|
41 |
+
logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
|
42 |
+
# define network net_g with Exponential Moving Average (EMA)
|
43 |
+
# net_g_ema is used only for testing on one GPU and saving
|
44 |
+
# There is no need to wrap with DistributedDataParallel
|
45 |
+
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
|
46 |
+
# load pretrained model
|
47 |
+
load_path = self.opt['path'].get('pretrain_network_g', None)
|
48 |
+
if load_path is not None:
|
49 |
+
self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
|
50 |
+
else:
|
51 |
+
self.model_ema(0) # copy net_g weight
|
52 |
+
self.net_g_ema.eval()
|
53 |
+
|
54 |
+
# define losses
|
55 |
+
if train_opt.get('pixel_opt'):
|
56 |
+
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
|
57 |
+
else:
|
58 |
+
self.cri_pix = None
|
59 |
+
|
60 |
+
if train_opt.get('perceptual_opt'):
|
61 |
+
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
|
62 |
+
else:
|
63 |
+
self.cri_perceptual = None
|
64 |
+
|
65 |
+
if self.cri_pix is None and self.cri_perceptual is None:
|
66 |
+
raise ValueError('Both pixel and perceptual losses are None.')
|
67 |
+
|
68 |
+
# set up optimizers and schedulers
|
69 |
+
self.setup_optimizers()
|
70 |
+
self.setup_schedulers()
|
71 |
+
|
72 |
+
def setup_optimizers(self):
|
73 |
+
train_opt = self.opt['train']
|
74 |
+
optim_params = []
|
75 |
+
for k, v in self.net_g.named_parameters():
|
76 |
+
if v.requires_grad:
|
77 |
+
optim_params.append(v)
|
78 |
+
else:
|
79 |
+
logger = get_root_logger()
|
80 |
+
logger.warning(f'Params {k} will not be optimized.')
|
81 |
+
|
82 |
+
optim_type = train_opt['optim_g'].pop('type')
|
83 |
+
self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
|
84 |
+
self.optimizers.append(self.optimizer_g)
|
85 |
+
|
86 |
+
def feed_data(self, data):
|
87 |
+
self.lq = data['lq'].to(self.device)
|
88 |
+
if 'gt' in data:
|
89 |
+
self.gt = data['gt'].to(self.device)
|
90 |
+
|
91 |
+
def optimize_parameters(self, current_iter):
|
92 |
+
self.optimizer_g.zero_grad()
|
93 |
+
self.output = self.net_g(self.lq)
|
94 |
+
|
95 |
+
l_total = 0
|
96 |
+
loss_dict = OrderedDict()
|
97 |
+
# pixel loss
|
98 |
+
if self.cri_pix:
|
99 |
+
l_pix = self.cri_pix(self.output, self.gt)
|
100 |
+
l_total += l_pix
|
101 |
+
loss_dict['l_pix'] = l_pix
|
102 |
+
# perceptual loss
|
103 |
+
if self.cri_perceptual:
|
104 |
+
l_percep, l_style = self.cri_perceptual(self.output, self.gt)
|
105 |
+
if l_percep is not None:
|
106 |
+
l_total += l_percep
|
107 |
+
loss_dict['l_percep'] = l_percep
|
108 |
+
if l_style is not None:
|
109 |
+
l_total += l_style
|
110 |
+
loss_dict['l_style'] = l_style
|
111 |
+
|
112 |
+
l_total.backward()
|
113 |
+
self.optimizer_g.step()
|
114 |
+
|
115 |
+
self.log_dict = self.reduce_loss_dict(loss_dict)
|
116 |
+
|
117 |
+
if self.ema_decay > 0:
|
118 |
+
self.model_ema(decay=self.ema_decay)
|
119 |
+
|
120 |
+
def test(self):
|
121 |
+
if hasattr(self, 'ema_decay'):
|
122 |
+
self.net_g_ema.eval()
|
123 |
+
with torch.no_grad():
|
124 |
+
self.output = self.net_g_ema(self.lq)
|
125 |
+
else:
|
126 |
+
self.net_g.eval()
|
127 |
+
with torch.no_grad():
|
128 |
+
self.output = self.net_g(self.lq)
|
129 |
+
self.net_g.train()
|
130 |
+
|
131 |
+
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
132 |
+
if self.opt['rank'] == 0:
|
133 |
+
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
134 |
+
|
135 |
+
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
136 |
+
dataset_name = dataloader.dataset.opt['name']
|
137 |
+
with_metrics = self.opt['val'].get('metrics') is not None
|
138 |
+
if with_metrics:
|
139 |
+
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
|
140 |
+
pbar = tqdm(total=len(dataloader), unit='image')
|
141 |
+
|
142 |
+
for idx, val_data in enumerate(dataloader):
|
143 |
+
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
|
144 |
+
self.feed_data(val_data)
|
145 |
+
self.test()
|
146 |
+
|
147 |
+
visuals = self.get_current_visuals()
|
148 |
+
sr_img = tensor2img([visuals['result']])
|
149 |
+
if 'gt' in visuals:
|
150 |
+
gt_img = tensor2img([visuals['gt']])
|
151 |
+
del self.gt
|
152 |
+
|
153 |
+
# tentative for out of GPU memory
|
154 |
+
del self.lq
|
155 |
+
del self.output
|
156 |
+
torch.cuda.empty_cache()
|
157 |
+
|
158 |
+
if save_img:
|
159 |
+
if self.opt['is_train']:
|
160 |
+
save_img_path = osp.join(self.opt['path']['visualization'], img_name,
|
161 |
+
f'{img_name}_{current_iter}.png')
|
162 |
+
else:
|
163 |
+
if self.opt['val']['suffix']:
|
164 |
+
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
|
165 |
+
f'{img_name}_{self.opt["val"]["suffix"]}.png')
|
166 |
+
else:
|
167 |
+
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
|
168 |
+
f'{img_name}_{self.opt["name"]}.png')
|
169 |
+
imwrite(sr_img, save_img_path)
|
170 |
+
|
171 |
+
if with_metrics:
|
172 |
+
# calculate metrics
|
173 |
+
for name, opt_ in self.opt['val']['metrics'].items():
|
174 |
+
metric_data = dict(img1=sr_img, img2=gt_img)
|
175 |
+
self.metric_results[name] += calculate_metric(metric_data, opt_)
|
176 |
+
pbar.update(1)
|
177 |
+
pbar.set_description(f'Test {img_name}')
|
178 |
+
pbar.close()
|
179 |
+
|
180 |
+
if with_metrics:
|
181 |
+
for metric in self.metric_results.keys():
|
182 |
+
self.metric_results[metric] /= (idx + 1)
|
183 |
+
|
184 |
+
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
|
185 |
+
|
186 |
+
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
|
187 |
+
log_str = f'Validation {dataset_name}\n'
|
188 |
+
for metric, value in self.metric_results.items():
|
189 |
+
log_str += f'\t # {metric}: {value:.4f}\n'
|
190 |
+
logger = get_root_logger()
|
191 |
+
logger.info(log_str)
|
192 |
+
if tb_logger:
|
193 |
+
for metric, value in self.metric_results.items():
|
194 |
+
tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
|
195 |
+
|
196 |
+
def get_current_visuals(self):
|
197 |
+
out_dict = OrderedDict()
|
198 |
+
out_dict['lq'] = self.lq.detach().cpu()
|
199 |
+
out_dict['result'] = self.output.detach().cpu()
|
200 |
+
if hasattr(self, 'gt'):
|
201 |
+
out_dict['gt'] = self.gt.detach().cpu()
|
202 |
+
return out_dict
|
203 |
+
|
204 |
+
def save(self, epoch, current_iter):
|
205 |
+
if hasattr(self, 'ema_decay'):
|
206 |
+
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
|
207 |
+
else:
|
208 |
+
self.save_network(self.net_g, 'net_g', current_iter)
|
209 |
+
self.save_training_state(epoch, current_iter)
|
blissful_tuner/codeformer/basicsr/models/vqgan_model.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from collections import OrderedDict
|
3 |
+
from os import path as osp
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
from codeformer.basicsr.archs import build_network
|
7 |
+
from codeformer.basicsr.losses import build_loss
|
8 |
+
from codeformer.basicsr.metrics import calculate_metric
|
9 |
+
from codeformer.basicsr.utils import get_root_logger, imwrite, tensor2img
|
10 |
+
from codeformer.basicsr.utils.registry import MODEL_REGISTRY
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from .sr_model import SRModel
|
13 |
+
|
14 |
+
|
15 |
+
@MODEL_REGISTRY.register()
|
16 |
+
class VQGANModel(SRModel):
|
17 |
+
def feed_data(self, data):
|
18 |
+
self.gt = data['gt'].to(self.device)
|
19 |
+
self.b = self.gt.shape[0]
|
20 |
+
|
21 |
+
|
22 |
+
def init_training_settings(self):
|
23 |
+
logger = get_root_logger()
|
24 |
+
train_opt = self.opt['train']
|
25 |
+
|
26 |
+
self.ema_decay = train_opt.get('ema_decay', 0)
|
27 |
+
if self.ema_decay > 0:
|
28 |
+
logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
|
29 |
+
# define network net_g with Exponential Moving Average (EMA)
|
30 |
+
# net_g_ema is used only for testing on one GPU and saving
|
31 |
+
# There is no need to wrap with DistributedDataParallel
|
32 |
+
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
|
33 |
+
# load pretrained model
|
34 |
+
load_path = self.opt['path'].get('pretrain_network_g', None)
|
35 |
+
if load_path is not None:
|
36 |
+
self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
|
37 |
+
else:
|
38 |
+
self.model_ema(0) # copy net_g weight
|
39 |
+
self.net_g_ema.eval()
|
40 |
+
|
41 |
+
# define network net_d
|
42 |
+
self.net_d = build_network(self.opt['network_d'])
|
43 |
+
self.net_d = self.model_to_device(self.net_d)
|
44 |
+
self.print_network(self.net_d)
|
45 |
+
|
46 |
+
# load pretrained models
|
47 |
+
load_path = self.opt['path'].get('pretrain_network_d', None)
|
48 |
+
if load_path is not None:
|
49 |
+
self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
|
50 |
+
|
51 |
+
self.net_g.train()
|
52 |
+
self.net_d.train()
|
53 |
+
|
54 |
+
# define losses
|
55 |
+
if train_opt.get('pixel_opt'):
|
56 |
+
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
|
57 |
+
else:
|
58 |
+
self.cri_pix = None
|
59 |
+
|
60 |
+
if train_opt.get('perceptual_opt'):
|
61 |
+
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
|
62 |
+
else:
|
63 |
+
self.cri_perceptual = None
|
64 |
+
|
65 |
+
if train_opt.get('gan_opt'):
|
66 |
+
self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
|
67 |
+
|
68 |
+
if train_opt.get('codebook_opt'):
|
69 |
+
self.l_weight_codebook = train_opt['codebook_opt'].get('loss_weight', 1.0)
|
70 |
+
else:
|
71 |
+
self.l_weight_codebook = 1.0
|
72 |
+
|
73 |
+
self.vqgan_quantizer = self.opt['network_g']['quantizer']
|
74 |
+
logger.info(f'vqgan_quantizer: {self.vqgan_quantizer}')
|
75 |
+
|
76 |
+
self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
|
77 |
+
self.net_d_iters = train_opt.get('net_d_iters', 1)
|
78 |
+
self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
|
79 |
+
self.disc_weight = train_opt.get('disc_weight', 0.8)
|
80 |
+
|
81 |
+
# set up optimizers and schedulers
|
82 |
+
self.setup_optimizers()
|
83 |
+
self.setup_schedulers()
|
84 |
+
|
85 |
+
def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
|
86 |
+
recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
|
87 |
+
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
88 |
+
|
89 |
+
d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
|
90 |
+
d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
|
91 |
+
return d_weight
|
92 |
+
|
93 |
+
def adopt_weight(self, weight, global_step, threshold=0, value=0.):
|
94 |
+
if global_step < threshold:
|
95 |
+
weight = value
|
96 |
+
return weight
|
97 |
+
|
98 |
+
def setup_optimizers(self):
|
99 |
+
train_opt = self.opt['train']
|
100 |
+
# optimizer g
|
101 |
+
optim_params_g = []
|
102 |
+
for k, v in self.net_g.named_parameters():
|
103 |
+
if v.requires_grad:
|
104 |
+
optim_params_g.append(v)
|
105 |
+
else:
|
106 |
+
logger = get_root_logger()
|
107 |
+
logger.warning(f'Params {k} will not be optimized.')
|
108 |
+
optim_type = train_opt['optim_g'].pop('type')
|
109 |
+
self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
|
110 |
+
self.optimizers.append(self.optimizer_g)
|
111 |
+
# optimizer d
|
112 |
+
optim_type = train_opt['optim_d'].pop('type')
|
113 |
+
self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
|
114 |
+
self.optimizers.append(self.optimizer_d)
|
115 |
+
|
116 |
+
|
117 |
+
def optimize_parameters(self, current_iter):
|
118 |
+
logger = get_root_logger()
|
119 |
+
loss_dict = OrderedDict()
|
120 |
+
if self.opt['network_g']['quantizer'] == 'gumbel':
|
121 |
+
self.net_g.module.quantize.temperature = max(1/16, ((-1/160000) * current_iter) + 1)
|
122 |
+
if current_iter%1000 == 0:
|
123 |
+
logger.info(f'temperature: {self.net_g.module.quantize.temperature}')
|
124 |
+
|
125 |
+
# optimize net_g
|
126 |
+
for p in self.net_d.parameters():
|
127 |
+
p.requires_grad = False
|
128 |
+
|
129 |
+
self.optimizer_g.zero_grad()
|
130 |
+
self.output, l_codebook, quant_stats = self.net_g(self.gt)
|
131 |
+
|
132 |
+
l_codebook = l_codebook*self.l_weight_codebook
|
133 |
+
|
134 |
+
l_g_total = 0
|
135 |
+
if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
|
136 |
+
# pixel loss
|
137 |
+
if self.cri_pix:
|
138 |
+
l_g_pix = self.cri_pix(self.output, self.gt)
|
139 |
+
l_g_total += l_g_pix
|
140 |
+
loss_dict['l_g_pix'] = l_g_pix
|
141 |
+
# perceptual loss
|
142 |
+
if self.cri_perceptual:
|
143 |
+
l_g_percep = self.cri_perceptual(self.output, self.gt)
|
144 |
+
l_g_total += l_g_percep
|
145 |
+
loss_dict['l_g_percep'] = l_g_percep
|
146 |
+
|
147 |
+
# gan loss
|
148 |
+
if current_iter > self.net_d_start_iter:
|
149 |
+
# fake_g_pred = self.net_d(self.output_1024)
|
150 |
+
fake_g_pred = self.net_d(self.output)
|
151 |
+
l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
|
152 |
+
recon_loss = l_g_total
|
153 |
+
last_layer = self.net_g.module.generator.blocks[-1].weight
|
154 |
+
d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
|
155 |
+
d_weight *= self.adopt_weight(1, current_iter, self.net_d_start_iter)
|
156 |
+
d_weight *= self.disc_weight # tamming setting 0.8
|
157 |
+
l_g_total += d_weight * l_g_gan
|
158 |
+
loss_dict['l_g_gan'] = d_weight * l_g_gan
|
159 |
+
|
160 |
+
l_g_total += l_codebook
|
161 |
+
loss_dict['l_codebook'] = l_codebook
|
162 |
+
|
163 |
+
l_g_total.backward()
|
164 |
+
self.optimizer_g.step()
|
165 |
+
|
166 |
+
# optimize net_d
|
167 |
+
if current_iter > self.net_d_start_iter:
|
168 |
+
for p in self.net_d.parameters():
|
169 |
+
p.requires_grad = True
|
170 |
+
|
171 |
+
self.optimizer_d.zero_grad()
|
172 |
+
# real
|
173 |
+
real_d_pred = self.net_d(self.gt)
|
174 |
+
l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
|
175 |
+
loss_dict['l_d_real'] = l_d_real
|
176 |
+
loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
|
177 |
+
l_d_real.backward()
|
178 |
+
# fake
|
179 |
+
fake_d_pred = self.net_d(self.output.detach())
|
180 |
+
l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
|
181 |
+
loss_dict['l_d_fake'] = l_d_fake
|
182 |
+
loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
|
183 |
+
l_d_fake.backward()
|
184 |
+
self.optimizer_d.step()
|
185 |
+
|
186 |
+
self.log_dict = self.reduce_loss_dict(loss_dict)
|
187 |
+
|
188 |
+
if self.ema_decay > 0:
|
189 |
+
self.model_ema(decay=self.ema_decay)
|
190 |
+
|
191 |
+
|
192 |
+
def test(self):
|
193 |
+
with torch.no_grad():
|
194 |
+
if hasattr(self, 'net_g_ema'):
|
195 |
+
self.net_g_ema.eval()
|
196 |
+
self.output, _, _ = self.net_g_ema(self.gt)
|
197 |
+
else:
|
198 |
+
logger = get_root_logger()
|
199 |
+
logger.warning('Do not have self.net_g_ema, use self.net_g.')
|
200 |
+
self.net_g.eval()
|
201 |
+
self.output, _, _ = self.net_g(self.gt)
|
202 |
+
self.net_g.train()
|
203 |
+
|
204 |
+
|
205 |
+
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
206 |
+
if self.opt['rank'] == 0:
|
207 |
+
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
208 |
+
|
209 |
+
|
210 |
+
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
211 |
+
dataset_name = dataloader.dataset.opt['name']
|
212 |
+
with_metrics = self.opt['val'].get('metrics') is not None
|
213 |
+
if with_metrics:
|
214 |
+
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
|
215 |
+
pbar = tqdm(total=len(dataloader), unit='image')
|
216 |
+
|
217 |
+
for idx, val_data in enumerate(dataloader):
|
218 |
+
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
|
219 |
+
self.feed_data(val_data)
|
220 |
+
self.test()
|
221 |
+
|
222 |
+
visuals = self.get_current_visuals()
|
223 |
+
sr_img = tensor2img([visuals['result']])
|
224 |
+
if 'gt' in visuals:
|
225 |
+
gt_img = tensor2img([visuals['gt']])
|
226 |
+
del self.gt
|
227 |
+
|
228 |
+
# tentative for out of GPU memory
|
229 |
+
del self.lq
|
230 |
+
del self.output
|
231 |
+
torch.cuda.empty_cache()
|
232 |
+
|
233 |
+
if save_img:
|
234 |
+
if self.opt['is_train']:
|
235 |
+
save_img_path = osp.join(self.opt['path']['visualization'], img_name,
|
236 |
+
f'{img_name}_{current_iter}.png')
|
237 |
+
else:
|
238 |
+
if self.opt['val']['suffix']:
|
239 |
+
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
|
240 |
+
f'{img_name}_{self.opt["val"]["suffix"]}.png')
|
241 |
+
else:
|
242 |
+
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
|
243 |
+
f'{img_name}_{self.opt["name"]}.png')
|
244 |
+
imwrite(sr_img, save_img_path)
|
245 |
+
|
246 |
+
if with_metrics:
|
247 |
+
# calculate metrics
|
248 |
+
for name, opt_ in self.opt['val']['metrics'].items():
|
249 |
+
metric_data = dict(img1=sr_img, img2=gt_img)
|
250 |
+
self.metric_results[name] += calculate_metric(metric_data, opt_)
|
251 |
+
pbar.update(1)
|
252 |
+
pbar.set_description(f'Test {img_name}')
|
253 |
+
pbar.close()
|
254 |
+
|
255 |
+
if with_metrics:
|
256 |
+
for metric in self.metric_results.keys():
|
257 |
+
self.metric_results[metric] /= (idx + 1)
|
258 |
+
|
259 |
+
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
|
260 |
+
|
261 |
+
|
262 |
+
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
|
263 |
+
log_str = f'Validation {dataset_name}\n'
|
264 |
+
for metric, value in self.metric_results.items():
|
265 |
+
log_str += f'\t # {metric}: {value:.4f}\n'
|
266 |
+
logger = get_root_logger()
|
267 |
+
logger.info(log_str)
|
268 |
+
if tb_logger:
|
269 |
+
for metric, value in self.metric_results.items():
|
270 |
+
tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
|
271 |
+
|
272 |
+
|
273 |
+
def get_current_visuals(self):
|
274 |
+
out_dict = OrderedDict()
|
275 |
+
out_dict['gt'] = self.gt.detach().cpu()
|
276 |
+
out_dict['result'] = self.output.detach().cpu()
|
277 |
+
return out_dict
|
278 |
+
|
279 |
+
def save(self, epoch, current_iter):
|
280 |
+
if self.ema_decay > 0:
|
281 |
+
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
|
282 |
+
else:
|
283 |
+
self.save_network(self.net_g, 'net_g', current_iter)
|
284 |
+
self.save_network(self.net_d, 'net_d', current_iter)
|
285 |
+
self.save_training_state(epoch, current_iter)
|
blissful_tuner/codeformer/basicsr/ops/__init__.py
ADDED
File without changes
|
blissful_tuner/codeformer/basicsr/ops/dcn/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
|
2 |
+
modulated_deform_conv)
|
3 |
+
|
4 |
+
__all__ = [
|
5 |
+
'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
|
6 |
+
'modulated_deform_conv'
|
7 |
+
]
|
blissful_tuner/codeformer/basicsr/ops/dcn/deform_conv.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn as nn
|
4 |
+
from torch.autograd import Function
|
5 |
+
from torch.autograd.function import once_differentiable
|
6 |
+
from torch.nn import functional as F
|
7 |
+
from torch.nn.modules.utils import _pair, _single
|
8 |
+
|
9 |
+
try:
|
10 |
+
from . import deform_conv_ext
|
11 |
+
except ImportError:
|
12 |
+
import os
|
13 |
+
BASICSR_JIT = os.getenv('BASICSR_JIT')
|
14 |
+
if BASICSR_JIT == 'True':
|
15 |
+
from torch.utils.cpp_extension import load
|
16 |
+
module_path = os.path.dirname(__file__)
|
17 |
+
deform_conv_ext = load(
|
18 |
+
'deform_conv',
|
19 |
+
sources=[
|
20 |
+
os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
|
21 |
+
os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
|
22 |
+
os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
|
23 |
+
],
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
class DeformConvFunction(Function):
|
28 |
+
|
29 |
+
@staticmethod
|
30 |
+
def forward(ctx,
|
31 |
+
input,
|
32 |
+
offset,
|
33 |
+
weight,
|
34 |
+
stride=1,
|
35 |
+
padding=0,
|
36 |
+
dilation=1,
|
37 |
+
groups=1,
|
38 |
+
deformable_groups=1,
|
39 |
+
im2col_step=64):
|
40 |
+
if input is not None and input.dim() != 4:
|
41 |
+
raise ValueError(f'Expected 4D tensor as input, got {input.dim()}' 'D tensor instead.')
|
42 |
+
ctx.stride = _pair(stride)
|
43 |
+
ctx.padding = _pair(padding)
|
44 |
+
ctx.dilation = _pair(dilation)
|
45 |
+
ctx.groups = groups
|
46 |
+
ctx.deformable_groups = deformable_groups
|
47 |
+
ctx.im2col_step = im2col_step
|
48 |
+
|
49 |
+
ctx.save_for_backward(input, offset, weight)
|
50 |
+
|
51 |
+
output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
|
52 |
+
|
53 |
+
ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
|
54 |
+
|
55 |
+
if not input.is_cuda:
|
56 |
+
raise NotImplementedError
|
57 |
+
else:
|
58 |
+
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
|
59 |
+
assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
|
60 |
+
deform_conv_ext.deform_conv_forward(input, weight,
|
61 |
+
offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
|
62 |
+
weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
|
63 |
+
ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
|
64 |
+
ctx.deformable_groups, cur_im2col_step)
|
65 |
+
return output
|
66 |
+
|
67 |
+
@staticmethod
|
68 |
+
@once_differentiable
|
69 |
+
def backward(ctx, grad_output):
|
70 |
+
input, offset, weight = ctx.saved_tensors
|
71 |
+
|
72 |
+
grad_input = grad_offset = grad_weight = None
|
73 |
+
|
74 |
+
if not grad_output.is_cuda:
|
75 |
+
raise NotImplementedError
|
76 |
+
else:
|
77 |
+
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
|
78 |
+
assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
|
79 |
+
|
80 |
+
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
|
81 |
+
grad_input = torch.zeros_like(input)
|
82 |
+
grad_offset = torch.zeros_like(offset)
|
83 |
+
deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
|
84 |
+
grad_offset, weight, ctx.bufs_[0], weight.size(3),
|
85 |
+
weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
|
86 |
+
ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
|
87 |
+
ctx.deformable_groups, cur_im2col_step)
|
88 |
+
|
89 |
+
if ctx.needs_input_grad[2]:
|
90 |
+
grad_weight = torch.zeros_like(weight)
|
91 |
+
deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
|
92 |
+
ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
|
93 |
+
weight.size(2), ctx.stride[1], ctx.stride[0],
|
94 |
+
ctx.padding[1], ctx.padding[0], ctx.dilation[1],
|
95 |
+
ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
|
96 |
+
cur_im2col_step)
|
97 |
+
|
98 |
+
return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
|
99 |
+
|
100 |
+
@staticmethod
|
101 |
+
def _output_size(input, weight, padding, dilation, stride):
|
102 |
+
channels = weight.size(0)
|
103 |
+
output_size = (input.size(0), channels)
|
104 |
+
for d in range(input.dim() - 2):
|
105 |
+
in_size = input.size(d + 2)
|
106 |
+
pad = padding[d]
|
107 |
+
kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
|
108 |
+
stride_ = stride[d]
|
109 |
+
output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
|
110 |
+
if not all(map(lambda s: s > 0, output_size)):
|
111 |
+
raise ValueError('convolution input is too small (output would be ' f'{"x".join(map(str, output_size))})')
|
112 |
+
return output_size
|
113 |
+
|
114 |
+
|
115 |
+
class ModulatedDeformConvFunction(Function):
|
116 |
+
|
117 |
+
@staticmethod
|
118 |
+
def forward(ctx,
|
119 |
+
input,
|
120 |
+
offset,
|
121 |
+
mask,
|
122 |
+
weight,
|
123 |
+
bias=None,
|
124 |
+
stride=1,
|
125 |
+
padding=0,
|
126 |
+
dilation=1,
|
127 |
+
groups=1,
|
128 |
+
deformable_groups=1):
|
129 |
+
ctx.stride = stride
|
130 |
+
ctx.padding = padding
|
131 |
+
ctx.dilation = dilation
|
132 |
+
ctx.groups = groups
|
133 |
+
ctx.deformable_groups = deformable_groups
|
134 |
+
ctx.with_bias = bias is not None
|
135 |
+
if not ctx.with_bias:
|
136 |
+
bias = input.new_empty(1) # fake tensor
|
137 |
+
if not input.is_cuda:
|
138 |
+
raise NotImplementedError
|
139 |
+
if weight.requires_grad or mask.requires_grad or offset.requires_grad \
|
140 |
+
or input.requires_grad:
|
141 |
+
ctx.save_for_backward(input, offset, mask, weight, bias)
|
142 |
+
output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
|
143 |
+
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
|
144 |
+
deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
|
145 |
+
ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
|
146 |
+
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
|
147 |
+
ctx.groups, ctx.deformable_groups, ctx.with_bias)
|
148 |
+
return output
|
149 |
+
|
150 |
+
@staticmethod
|
151 |
+
@once_differentiable
|
152 |
+
def backward(ctx, grad_output):
|
153 |
+
if not grad_output.is_cuda:
|
154 |
+
raise NotImplementedError
|
155 |
+
input, offset, mask, weight, bias = ctx.saved_tensors
|
156 |
+
grad_input = torch.zeros_like(input)
|
157 |
+
grad_offset = torch.zeros_like(offset)
|
158 |
+
grad_mask = torch.zeros_like(mask)
|
159 |
+
grad_weight = torch.zeros_like(weight)
|
160 |
+
grad_bias = torch.zeros_like(bias)
|
161 |
+
deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
|
162 |
+
grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
|
163 |
+
grad_output, weight.shape[2], weight.shape[3], ctx.stride,
|
164 |
+
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
|
165 |
+
ctx.groups, ctx.deformable_groups, ctx.with_bias)
|
166 |
+
if not ctx.with_bias:
|
167 |
+
grad_bias = None
|
168 |
+
|
169 |
+
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)
|
170 |
+
|
171 |
+
@staticmethod
|
172 |
+
def _infer_shape(ctx, input, weight):
|
173 |
+
n = input.size(0)
|
174 |
+
channels_out = weight.size(0)
|
175 |
+
height, width = input.shape[2:4]
|
176 |
+
kernel_h, kernel_w = weight.shape[2:4]
|
177 |
+
height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
|
178 |
+
width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
|
179 |
+
return n, channels_out, height_out, width_out
|
180 |
+
|
181 |
+
|
182 |
+
deform_conv = DeformConvFunction.apply
|
183 |
+
modulated_deform_conv = ModulatedDeformConvFunction.apply
|
184 |
+
|
185 |
+
|
186 |
+
class DeformConv(nn.Module):
|
187 |
+
|
188 |
+
def __init__(self,
|
189 |
+
in_channels,
|
190 |
+
out_channels,
|
191 |
+
kernel_size,
|
192 |
+
stride=1,
|
193 |
+
padding=0,
|
194 |
+
dilation=1,
|
195 |
+
groups=1,
|
196 |
+
deformable_groups=1,
|
197 |
+
bias=False):
|
198 |
+
super(DeformConv, self).__init__()
|
199 |
+
|
200 |
+
assert not bias
|
201 |
+
assert in_channels % groups == 0, \
|
202 |
+
f'in_channels {in_channels} is not divisible by groups {groups}'
|
203 |
+
assert out_channels % groups == 0, \
|
204 |
+
f'out_channels {out_channels} is not divisible ' \
|
205 |
+
f'by groups {groups}'
|
206 |
+
|
207 |
+
self.in_channels = in_channels
|
208 |
+
self.out_channels = out_channels
|
209 |
+
self.kernel_size = _pair(kernel_size)
|
210 |
+
self.stride = _pair(stride)
|
211 |
+
self.padding = _pair(padding)
|
212 |
+
self.dilation = _pair(dilation)
|
213 |
+
self.groups = groups
|
214 |
+
self.deformable_groups = deformable_groups
|
215 |
+
# enable compatibility with nn.Conv2d
|
216 |
+
self.transposed = False
|
217 |
+
self.output_padding = _single(0)
|
218 |
+
|
219 |
+
self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
|
220 |
+
|
221 |
+
self.reset_parameters()
|
222 |
+
|
223 |
+
def reset_parameters(self):
|
224 |
+
n = self.in_channels
|
225 |
+
for k in self.kernel_size:
|
226 |
+
n *= k
|
227 |
+
stdv = 1. / math.sqrt(n)
|
228 |
+
self.weight.data.uniform_(-stdv, stdv)
|
229 |
+
|
230 |
+
def forward(self, x, offset):
|
231 |
+
# To fix an assert error in deform_conv_cuda.cpp:128
|
232 |
+
# input image is smaller than kernel
|
233 |
+
input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
|
234 |
+
if input_pad:
|
235 |
+
pad_h = max(self.kernel_size[0] - x.size(2), 0)
|
236 |
+
pad_w = max(self.kernel_size[1] - x.size(3), 0)
|
237 |
+
x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
|
238 |
+
offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
|
239 |
+
out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
|
240 |
+
self.deformable_groups)
|
241 |
+
if input_pad:
|
242 |
+
out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
|
243 |
+
return out
|
244 |
+
|
245 |
+
|
246 |
+
class DeformConvPack(DeformConv):
|
247 |
+
"""A Deformable Conv Encapsulation that acts as normal Conv layers.
|
248 |
+
|
249 |
+
Args:
|
250 |
+
in_channels (int): Same as nn.Conv2d.
|
251 |
+
out_channels (int): Same as nn.Conv2d.
|
252 |
+
kernel_size (int or tuple[int]): Same as nn.Conv2d.
|
253 |
+
stride (int or tuple[int]): Same as nn.Conv2d.
|
254 |
+
padding (int or tuple[int]): Same as nn.Conv2d.
|
255 |
+
dilation (int or tuple[int]): Same as nn.Conv2d.
|
256 |
+
groups (int): Same as nn.Conv2d.
|
257 |
+
bias (bool or str): If specified as `auto`, it will be decided by the
|
258 |
+
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
|
259 |
+
False.
|
260 |
+
"""
|
261 |
+
|
262 |
+
_version = 2
|
263 |
+
|
264 |
+
def __init__(self, *args, **kwargs):
|
265 |
+
super(DeformConvPack, self).__init__(*args, **kwargs)
|
266 |
+
|
267 |
+
self.conv_offset = nn.Conv2d(
|
268 |
+
self.in_channels,
|
269 |
+
self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
|
270 |
+
kernel_size=self.kernel_size,
|
271 |
+
stride=_pair(self.stride),
|
272 |
+
padding=_pair(self.padding),
|
273 |
+
dilation=_pair(self.dilation),
|
274 |
+
bias=True)
|
275 |
+
self.init_offset()
|
276 |
+
|
277 |
+
def init_offset(self):
|
278 |
+
self.conv_offset.weight.data.zero_()
|
279 |
+
self.conv_offset.bias.data.zero_()
|
280 |
+
|
281 |
+
def forward(self, x):
|
282 |
+
offset = self.conv_offset(x)
|
283 |
+
return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
|
284 |
+
self.deformable_groups)
|
285 |
+
|
286 |
+
|
287 |
+
class ModulatedDeformConv(nn.Module):
|
288 |
+
|
289 |
+
def __init__(self,
|
290 |
+
in_channels,
|
291 |
+
out_channels,
|
292 |
+
kernel_size,
|
293 |
+
stride=1,
|
294 |
+
padding=0,
|
295 |
+
dilation=1,
|
296 |
+
groups=1,
|
297 |
+
deformable_groups=1,
|
298 |
+
bias=True):
|
299 |
+
super(ModulatedDeformConv, self).__init__()
|
300 |
+
self.in_channels = in_channels
|
301 |
+
self.out_channels = out_channels
|
302 |
+
self.kernel_size = _pair(kernel_size)
|
303 |
+
self.stride = stride
|
304 |
+
self.padding = padding
|
305 |
+
self.dilation = dilation
|
306 |
+
self.groups = groups
|
307 |
+
self.deformable_groups = deformable_groups
|
308 |
+
self.with_bias = bias
|
309 |
+
# enable compatibility with nn.Conv2d
|
310 |
+
self.transposed = False
|
311 |
+
self.output_padding = _single(0)
|
312 |
+
|
313 |
+
self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
|
314 |
+
if bias:
|
315 |
+
self.bias = nn.Parameter(torch.Tensor(out_channels))
|
316 |
+
else:
|
317 |
+
self.register_parameter('bias', None)
|
318 |
+
self.init_weights()
|
319 |
+
|
320 |
+
def init_weights(self):
|
321 |
+
n = self.in_channels
|
322 |
+
for k in self.kernel_size:
|
323 |
+
n *= k
|
324 |
+
stdv = 1. / math.sqrt(n)
|
325 |
+
self.weight.data.uniform_(-stdv, stdv)
|
326 |
+
if self.bias is not None:
|
327 |
+
self.bias.data.zero_()
|
328 |
+
|
329 |
+
def forward(self, x, offset, mask):
|
330 |
+
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
|
331 |
+
self.groups, self.deformable_groups)
|
332 |
+
|
333 |
+
|
334 |
+
class ModulatedDeformConvPack(ModulatedDeformConv):
|
335 |
+
"""A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
|
336 |
+
|
337 |
+
Args:
|
338 |
+
in_channels (int): Same as nn.Conv2d.
|
339 |
+
out_channels (int): Same as nn.Conv2d.
|
340 |
+
kernel_size (int or tuple[int]): Same as nn.Conv2d.
|
341 |
+
stride (int or tuple[int]): Same as nn.Conv2d.
|
342 |
+
padding (int or tuple[int]): Same as nn.Conv2d.
|
343 |
+
dilation (int or tuple[int]): Same as nn.Conv2d.
|
344 |
+
groups (int): Same as nn.Conv2d.
|
345 |
+
bias (bool or str): If specified as `auto`, it will be decided by the
|
346 |
+
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
|
347 |
+
False.
|
348 |
+
"""
|
349 |
+
|
350 |
+
_version = 2
|
351 |
+
|
352 |
+
def __init__(self, *args, **kwargs):
|
353 |
+
super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
|
354 |
+
|
355 |
+
self.conv_offset = nn.Conv2d(
|
356 |
+
self.in_channels,
|
357 |
+
self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
|
358 |
+
kernel_size=self.kernel_size,
|
359 |
+
stride=_pair(self.stride),
|
360 |
+
padding=_pair(self.padding),
|
361 |
+
dilation=_pair(self.dilation),
|
362 |
+
bias=True)
|
363 |
+
self.init_weights()
|
364 |
+
|
365 |
+
def init_weights(self):
|
366 |
+
super(ModulatedDeformConvPack, self).init_weights()
|
367 |
+
if hasattr(self, 'conv_offset'):
|
368 |
+
self.conv_offset.weight.data.zero_()
|
369 |
+
self.conv_offset.bias.data.zero_()
|
370 |
+
|
371 |
+
def forward(self, x):
|
372 |
+
out = self.conv_offset(x)
|
373 |
+
o1, o2, mask = torch.chunk(out, 3, dim=1)
|
374 |
+
offset = torch.cat((o1, o2), dim=1)
|
375 |
+
mask = torch.sigmoid(mask)
|
376 |
+
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
|
377 |
+
self.groups, self.deformable_groups)
|
blissful_tuner/codeformer/basicsr/ops/dcn/src/deform_conv_cuda.cpp
ADDED
@@ -0,0 +1,685 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// modify from
|
2 |
+
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
|
3 |
+
|
4 |
+
#include <torch/extension.h>
|
5 |
+
#include <ATen/DeviceGuard.h>
|
6 |
+
|
7 |
+
#include <cmath>
|
8 |
+
#include <vector>
|
9 |
+
|
10 |
+
void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
|
11 |
+
const int channels, const int height, const int width,
|
12 |
+
const int ksize_h, const int ksize_w, const int pad_h,
|
13 |
+
const int pad_w, const int stride_h, const int stride_w,
|
14 |
+
const int dilation_h, const int dilation_w,
|
15 |
+
const int parallel_imgs, const int deformable_group,
|
16 |
+
at::Tensor data_col);
|
17 |
+
|
18 |
+
void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
|
19 |
+
const int channels, const int height, const int width,
|
20 |
+
const int ksize_h, const int ksize_w, const int pad_h,
|
21 |
+
const int pad_w, const int stride_h, const int stride_w,
|
22 |
+
const int dilation_h, const int dilation_w,
|
23 |
+
const int parallel_imgs, const int deformable_group,
|
24 |
+
at::Tensor grad_im);
|
25 |
+
|
26 |
+
void deformable_col2im_coord(
|
27 |
+
const at::Tensor data_col, const at::Tensor data_im,
|
28 |
+
const at::Tensor data_offset, const int channels, const int height,
|
29 |
+
const int width, const int ksize_h, const int ksize_w, const int pad_h,
|
30 |
+
const int pad_w, const int stride_h, const int stride_w,
|
31 |
+
const int dilation_h, const int dilation_w, const int parallel_imgs,
|
32 |
+
const int deformable_group, at::Tensor grad_offset);
|
33 |
+
|
34 |
+
void modulated_deformable_im2col_cuda(
|
35 |
+
const at::Tensor data_im, const at::Tensor data_offset,
|
36 |
+
const at::Tensor data_mask, const int batch_size, const int channels,
|
37 |
+
const int height_im, const int width_im, const int height_col,
|
38 |
+
const int width_col, const int kernel_h, const int kenerl_w,
|
39 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
40 |
+
const int dilation_h, const int dilation_w, const int deformable_group,
|
41 |
+
at::Tensor data_col);
|
42 |
+
|
43 |
+
void modulated_deformable_col2im_cuda(
|
44 |
+
const at::Tensor data_col, const at::Tensor data_offset,
|
45 |
+
const at::Tensor data_mask, const int batch_size, const int channels,
|
46 |
+
const int height_im, const int width_im, const int height_col,
|
47 |
+
const int width_col, const int kernel_h, const int kenerl_w,
|
48 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
49 |
+
const int dilation_h, const int dilation_w, const int deformable_group,
|
50 |
+
at::Tensor grad_im);
|
51 |
+
|
52 |
+
void modulated_deformable_col2im_coord_cuda(
|
53 |
+
const at::Tensor data_col, const at::Tensor data_im,
|
54 |
+
const at::Tensor data_offset, const at::Tensor data_mask,
|
55 |
+
const int batch_size, const int channels, const int height_im,
|
56 |
+
const int width_im, const int height_col, const int width_col,
|
57 |
+
const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
|
58 |
+
const int stride_h, const int stride_w, const int dilation_h,
|
59 |
+
const int dilation_w, const int deformable_group, at::Tensor grad_offset,
|
60 |
+
at::Tensor grad_mask);
|
61 |
+
|
62 |
+
void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
|
63 |
+
at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
|
64 |
+
int padW, int dilationH, int dilationW, int group,
|
65 |
+
int deformable_group) {
|
66 |
+
TORCH_CHECK(weight.ndimension() == 4,
|
67 |
+
"4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
|
68 |
+
"but got: %s",
|
69 |
+
weight.ndimension());
|
70 |
+
|
71 |
+
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
72 |
+
|
73 |
+
TORCH_CHECK(kW > 0 && kH > 0,
|
74 |
+
"kernel size should be greater than zero, but got kH: %d kW: %d", kH,
|
75 |
+
kW);
|
76 |
+
|
77 |
+
TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
|
78 |
+
"kernel size should be consistent with weight, ",
|
79 |
+
"but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
|
80 |
+
kW, weight.size(2), weight.size(3));
|
81 |
+
|
82 |
+
TORCH_CHECK(dW > 0 && dH > 0,
|
83 |
+
"stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
|
84 |
+
|
85 |
+
TORCH_CHECK(
|
86 |
+
dilationW > 0 && dilationH > 0,
|
87 |
+
"dilation should be greater than 0, but got dilationH: %d dilationW: %d",
|
88 |
+
dilationH, dilationW);
|
89 |
+
|
90 |
+
int ndim = input.ndimension();
|
91 |
+
int dimf = 0;
|
92 |
+
int dimh = 1;
|
93 |
+
int dimw = 2;
|
94 |
+
|
95 |
+
if (ndim == 4) {
|
96 |
+
dimf++;
|
97 |
+
dimh++;
|
98 |
+
dimw++;
|
99 |
+
}
|
100 |
+
|
101 |
+
TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
|
102 |
+
ndim);
|
103 |
+
|
104 |
+
long nInputPlane = weight.size(1) * group;
|
105 |
+
long inputHeight = input.size(dimh);
|
106 |
+
long inputWidth = input.size(dimw);
|
107 |
+
long nOutputPlane = weight.size(0);
|
108 |
+
long outputHeight =
|
109 |
+
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
110 |
+
long outputWidth =
|
111 |
+
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
112 |
+
|
113 |
+
TORCH_CHECK(nInputPlane % deformable_group == 0,
|
114 |
+
"input channels must divide deformable group size");
|
115 |
+
|
116 |
+
if (outputWidth < 1 || outputHeight < 1)
|
117 |
+
AT_ERROR(
|
118 |
+
"Given input size: (%ld x %ld x %ld). "
|
119 |
+
"Calculated output size: (%ld x %ld x %ld). Output size is too small",
|
120 |
+
nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
|
121 |
+
outputWidth);
|
122 |
+
|
123 |
+
TORCH_CHECK(input.size(1) == nInputPlane,
|
124 |
+
"invalid number of input planes, expected: %d, but got: %d",
|
125 |
+
nInputPlane, input.size(1));
|
126 |
+
|
127 |
+
TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
|
128 |
+
"input image is smaller than kernel");
|
129 |
+
|
130 |
+
TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
|
131 |
+
"invalid spatial size of offset, expected height: %d width: %d, but "
|
132 |
+
"got height: %d width: %d",
|
133 |
+
outputHeight, outputWidth, offset.size(2), offset.size(3));
|
134 |
+
|
135 |
+
TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
|
136 |
+
"invalid number of channels of offset");
|
137 |
+
|
138 |
+
if (gradOutput != NULL) {
|
139 |
+
TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
|
140 |
+
"invalid number of gradOutput planes, expected: %d, but got: %d",
|
141 |
+
nOutputPlane, gradOutput->size(dimf));
|
142 |
+
|
143 |
+
TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
|
144 |
+
gradOutput->size(dimw) == outputWidth),
|
145 |
+
"invalid size of gradOutput, expected height: %d width: %d , but "
|
146 |
+
"got height: %d width: %d",
|
147 |
+
outputHeight, outputWidth, gradOutput->size(dimh),
|
148 |
+
gradOutput->size(dimw));
|
149 |
+
}
|
150 |
+
}
|
151 |
+
|
152 |
+
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
|
153 |
+
at::Tensor offset, at::Tensor output,
|
154 |
+
at::Tensor columns, at::Tensor ones, int kW,
|
155 |
+
int kH, int dW, int dH, int padW, int padH,
|
156 |
+
int dilationW, int dilationH, int group,
|
157 |
+
int deformable_group, int im2col_step) {
|
158 |
+
// todo: resize columns to include im2col: done
|
159 |
+
// todo: add im2col_step as input
|
160 |
+
// todo: add new output buffer and transpose it to output (or directly
|
161 |
+
// transpose output) todo: possibly change data indexing because of
|
162 |
+
// parallel_imgs
|
163 |
+
|
164 |
+
shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
|
165 |
+
dilationH, dilationW, group, deformable_group);
|
166 |
+
at::DeviceGuard guard(input.device());
|
167 |
+
|
168 |
+
input = input.contiguous();
|
169 |
+
offset = offset.contiguous();
|
170 |
+
weight = weight.contiguous();
|
171 |
+
|
172 |
+
int batch = 1;
|
173 |
+
if (input.ndimension() == 3) {
|
174 |
+
// Force batch
|
175 |
+
batch = 0;
|
176 |
+
input.unsqueeze_(0);
|
177 |
+
offset.unsqueeze_(0);
|
178 |
+
}
|
179 |
+
|
180 |
+
// todo: assert batchsize dividable by im2col_step
|
181 |
+
|
182 |
+
long batchSize = input.size(0);
|
183 |
+
long nInputPlane = input.size(1);
|
184 |
+
long inputHeight = input.size(2);
|
185 |
+
long inputWidth = input.size(3);
|
186 |
+
|
187 |
+
long nOutputPlane = weight.size(0);
|
188 |
+
|
189 |
+
long outputWidth =
|
190 |
+
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
191 |
+
long outputHeight =
|
192 |
+
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
193 |
+
|
194 |
+
TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
|
195 |
+
|
196 |
+
output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
|
197 |
+
outputHeight, outputWidth});
|
198 |
+
columns = at::zeros(
|
199 |
+
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
200 |
+
input.options());
|
201 |
+
|
202 |
+
if (ones.ndimension() != 2 ||
|
203 |
+
ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
|
204 |
+
ones = at::ones({outputHeight, outputWidth}, input.options());
|
205 |
+
}
|
206 |
+
|
207 |
+
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
208 |
+
inputHeight, inputWidth});
|
209 |
+
offset =
|
210 |
+
offset.view({batchSize / im2col_step, im2col_step,
|
211 |
+
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
212 |
+
|
213 |
+
at::Tensor output_buffer =
|
214 |
+
at::zeros({batchSize / im2col_step, nOutputPlane,
|
215 |
+
im2col_step * outputHeight, outputWidth},
|
216 |
+
output.options());
|
217 |
+
|
218 |
+
output_buffer = output_buffer.view(
|
219 |
+
{output_buffer.size(0), group, output_buffer.size(1) / group,
|
220 |
+
output_buffer.size(2), output_buffer.size(3)});
|
221 |
+
|
222 |
+
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
223 |
+
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
|
224 |
+
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
225 |
+
dilationW, im2col_step, deformable_group, columns);
|
226 |
+
|
227 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
228 |
+
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
229 |
+
weight.size(2), weight.size(3)});
|
230 |
+
|
231 |
+
for (int g = 0; g < group; g++) {
|
232 |
+
output_buffer[elt][g] = output_buffer[elt][g]
|
233 |
+
.flatten(1)
|
234 |
+
.addmm_(weight[g].flatten(1), columns[g])
|
235 |
+
.view_as(output_buffer[elt][g]);
|
236 |
+
}
|
237 |
+
}
|
238 |
+
|
239 |
+
output_buffer = output_buffer.view(
|
240 |
+
{output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
|
241 |
+
output_buffer.size(3), output_buffer.size(4)});
|
242 |
+
|
243 |
+
output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
|
244 |
+
im2col_step, outputHeight, outputWidth});
|
245 |
+
output_buffer.transpose_(1, 2);
|
246 |
+
output.copy_(output_buffer);
|
247 |
+
output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
248 |
+
|
249 |
+
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
250 |
+
offset = offset.view(
|
251 |
+
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
252 |
+
|
253 |
+
if (batch == 0) {
|
254 |
+
output = output.view({nOutputPlane, outputHeight, outputWidth});
|
255 |
+
input = input.view({nInputPlane, inputHeight, inputWidth});
|
256 |
+
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
|
257 |
+
}
|
258 |
+
|
259 |
+
return 1;
|
260 |
+
}
|
261 |
+
|
262 |
+
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
|
263 |
+
at::Tensor gradOutput, at::Tensor gradInput,
|
264 |
+
at::Tensor gradOffset, at::Tensor weight,
|
265 |
+
at::Tensor columns, int kW, int kH, int dW,
|
266 |
+
int dH, int padW, int padH, int dilationW,
|
267 |
+
int dilationH, int group,
|
268 |
+
int deformable_group, int im2col_step) {
|
269 |
+
shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
|
270 |
+
dilationH, dilationW, group, deformable_group);
|
271 |
+
at::DeviceGuard guard(input.device());
|
272 |
+
|
273 |
+
input = input.contiguous();
|
274 |
+
offset = offset.contiguous();
|
275 |
+
gradOutput = gradOutput.contiguous();
|
276 |
+
weight = weight.contiguous();
|
277 |
+
|
278 |
+
int batch = 1;
|
279 |
+
|
280 |
+
if (input.ndimension() == 3) {
|
281 |
+
// Force batch
|
282 |
+
batch = 0;
|
283 |
+
input = input.view({1, input.size(0), input.size(1), input.size(2)});
|
284 |
+
offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
|
285 |
+
gradOutput = gradOutput.view(
|
286 |
+
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
|
287 |
+
}
|
288 |
+
|
289 |
+
long batchSize = input.size(0);
|
290 |
+
long nInputPlane = input.size(1);
|
291 |
+
long inputHeight = input.size(2);
|
292 |
+
long inputWidth = input.size(3);
|
293 |
+
|
294 |
+
long nOutputPlane = weight.size(0);
|
295 |
+
|
296 |
+
long outputWidth =
|
297 |
+
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
298 |
+
long outputHeight =
|
299 |
+
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
300 |
+
|
301 |
+
TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
|
302 |
+
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
303 |
+
columns = at::zeros(
|
304 |
+
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
305 |
+
input.options());
|
306 |
+
|
307 |
+
// change order of grad output
|
308 |
+
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
|
309 |
+
nOutputPlane, outputHeight, outputWidth});
|
310 |
+
gradOutput.transpose_(1, 2);
|
311 |
+
|
312 |
+
gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
313 |
+
inputHeight, inputWidth});
|
314 |
+
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
315 |
+
inputHeight, inputWidth});
|
316 |
+
gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
|
317 |
+
deformable_group * 2 * kH * kW, outputHeight,
|
318 |
+
outputWidth});
|
319 |
+
offset =
|
320 |
+
offset.view({batchSize / im2col_step, im2col_step,
|
321 |
+
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
322 |
+
|
323 |
+
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
324 |
+
// divide into groups
|
325 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
326 |
+
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
327 |
+
weight.size(2), weight.size(3)});
|
328 |
+
gradOutput = gradOutput.view(
|
329 |
+
{gradOutput.size(0), group, gradOutput.size(1) / group,
|
330 |
+
gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
|
331 |
+
|
332 |
+
for (int g = 0; g < group; g++) {
|
333 |
+
columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
|
334 |
+
gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
|
335 |
+
}
|
336 |
+
|
337 |
+
columns =
|
338 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
339 |
+
gradOutput = gradOutput.view(
|
340 |
+
{gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
|
341 |
+
gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
|
342 |
+
|
343 |
+
deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
|
344 |
+
inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
|
345 |
+
dilationH, dilationW, im2col_step, deformable_group,
|
346 |
+
gradOffset[elt]);
|
347 |
+
|
348 |
+
deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
|
349 |
+
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
350 |
+
dilationW, im2col_step, deformable_group, gradInput[elt]);
|
351 |
+
}
|
352 |
+
|
353 |
+
gradOutput.transpose_(1, 2);
|
354 |
+
gradOutput =
|
355 |
+
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
356 |
+
|
357 |
+
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
358 |
+
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
359 |
+
gradOffset = gradOffset.view(
|
360 |
+
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
361 |
+
offset = offset.view(
|
362 |
+
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
363 |
+
|
364 |
+
if (batch == 0) {
|
365 |
+
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
|
366 |
+
input = input.view({nInputPlane, inputHeight, inputWidth});
|
367 |
+
gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
|
368 |
+
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
|
369 |
+
gradOffset =
|
370 |
+
gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
|
371 |
+
}
|
372 |
+
|
373 |
+
return 1;
|
374 |
+
}
|
375 |
+
|
376 |
+
int deform_conv_backward_parameters_cuda(
|
377 |
+
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
|
378 |
+
at::Tensor gradWeight, // at::Tensor gradBias,
|
379 |
+
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
|
380 |
+
int padW, int padH, int dilationW, int dilationH, int group,
|
381 |
+
int deformable_group, float scale, int im2col_step) {
|
382 |
+
// todo: transpose and reshape outGrad
|
383 |
+
// todo: reshape columns
|
384 |
+
// todo: add im2col_step as input
|
385 |
+
|
386 |
+
shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
|
387 |
+
padW, dilationH, dilationW, group, deformable_group);
|
388 |
+
at::DeviceGuard guard(input.device());
|
389 |
+
|
390 |
+
input = input.contiguous();
|
391 |
+
offset = offset.contiguous();
|
392 |
+
gradOutput = gradOutput.contiguous();
|
393 |
+
|
394 |
+
int batch = 1;
|
395 |
+
|
396 |
+
if (input.ndimension() == 3) {
|
397 |
+
// Force batch
|
398 |
+
batch = 0;
|
399 |
+
input = input.view(
|
400 |
+
at::IntList({1, input.size(0), input.size(1), input.size(2)}));
|
401 |
+
gradOutput = gradOutput.view(
|
402 |
+
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
|
403 |
+
}
|
404 |
+
|
405 |
+
long batchSize = input.size(0);
|
406 |
+
long nInputPlane = input.size(1);
|
407 |
+
long inputHeight = input.size(2);
|
408 |
+
long inputWidth = input.size(3);
|
409 |
+
|
410 |
+
long nOutputPlane = gradWeight.size(0);
|
411 |
+
|
412 |
+
long outputWidth =
|
413 |
+
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
414 |
+
long outputHeight =
|
415 |
+
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
416 |
+
|
417 |
+
TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
|
418 |
+
|
419 |
+
columns = at::zeros(
|
420 |
+
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
421 |
+
input.options());
|
422 |
+
|
423 |
+
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
|
424 |
+
nOutputPlane, outputHeight, outputWidth});
|
425 |
+
gradOutput.transpose_(1, 2);
|
426 |
+
|
427 |
+
at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
|
428 |
+
gradOutputBuffer =
|
429 |
+
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
|
430 |
+
outputHeight, outputWidth});
|
431 |
+
gradOutputBuffer.copy_(gradOutput);
|
432 |
+
gradOutputBuffer =
|
433 |
+
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
|
434 |
+
im2col_step * outputHeight, outputWidth});
|
435 |
+
|
436 |
+
gradOutput.transpose_(1, 2);
|
437 |
+
gradOutput =
|
438 |
+
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
439 |
+
|
440 |
+
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
441 |
+
inputHeight, inputWidth});
|
442 |
+
offset =
|
443 |
+
offset.view({batchSize / im2col_step, im2col_step,
|
444 |
+
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
445 |
+
|
446 |
+
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
447 |
+
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
|
448 |
+
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
449 |
+
dilationW, im2col_step, deformable_group, columns);
|
450 |
+
|
451 |
+
// divide into group
|
452 |
+
gradOutputBuffer = gradOutputBuffer.view(
|
453 |
+
{gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
|
454 |
+
gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
|
455 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
456 |
+
gradWeight =
|
457 |
+
gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
|
458 |
+
gradWeight.size(2), gradWeight.size(3)});
|
459 |
+
|
460 |
+
for (int g = 0; g < group; g++) {
|
461 |
+
gradWeight[g] = gradWeight[g]
|
462 |
+
.flatten(1)
|
463 |
+
.addmm_(gradOutputBuffer[elt][g].flatten(1),
|
464 |
+
columns[g].transpose(1, 0), 1.0, scale)
|
465 |
+
.view_as(gradWeight[g]);
|
466 |
+
}
|
467 |
+
gradOutputBuffer = gradOutputBuffer.view(
|
468 |
+
{gradOutputBuffer.size(0),
|
469 |
+
gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
|
470 |
+
gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
|
471 |
+
columns =
|
472 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
473 |
+
gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
|
474 |
+
gradWeight.size(2), gradWeight.size(3),
|
475 |
+
gradWeight.size(4)});
|
476 |
+
}
|
477 |
+
|
478 |
+
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
479 |
+
offset = offset.view(
|
480 |
+
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
481 |
+
|
482 |
+
if (batch == 0) {
|
483 |
+
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
|
484 |
+
input = input.view({nInputPlane, inputHeight, inputWidth});
|
485 |
+
}
|
486 |
+
|
487 |
+
return 1;
|
488 |
+
}
|
489 |
+
|
490 |
+
void modulated_deform_conv_cuda_forward(
|
491 |
+
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
|
492 |
+
at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
|
493 |
+
int kernel_h, int kernel_w, const int stride_h, const int stride_w,
|
494 |
+
const int pad_h, const int pad_w, const int dilation_h,
|
495 |
+
const int dilation_w, const int group, const int deformable_group,
|
496 |
+
const bool with_bias) {
|
497 |
+
TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
|
498 |
+
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
499 |
+
at::DeviceGuard guard(input.device());
|
500 |
+
|
501 |
+
const int batch = input.size(0);
|
502 |
+
const int channels = input.size(1);
|
503 |
+
const int height = input.size(2);
|
504 |
+
const int width = input.size(3);
|
505 |
+
|
506 |
+
const int channels_out = weight.size(0);
|
507 |
+
const int channels_kernel = weight.size(1);
|
508 |
+
const int kernel_h_ = weight.size(2);
|
509 |
+
const int kernel_w_ = weight.size(3);
|
510 |
+
|
511 |
+
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
|
512 |
+
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
|
513 |
+
kernel_h_, kernel_w, kernel_h_, kernel_w_);
|
514 |
+
if (channels != channels_kernel * group)
|
515 |
+
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
|
516 |
+
channels, channels_kernel * group);
|
517 |
+
|
518 |
+
const int height_out =
|
519 |
+
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
|
520 |
+
const int width_out =
|
521 |
+
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
|
522 |
+
|
523 |
+
if (ones.ndimension() != 2 ||
|
524 |
+
ones.size(0) * ones.size(1) < height_out * width_out) {
|
525 |
+
// Resize plane and fill with ones...
|
526 |
+
ones = at::ones({height_out, width_out}, input.options());
|
527 |
+
}
|
528 |
+
|
529 |
+
// resize output
|
530 |
+
output = output.view({batch, channels_out, height_out, width_out}).zero_();
|
531 |
+
// resize temporary columns
|
532 |
+
columns =
|
533 |
+
at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
|
534 |
+
input.options());
|
535 |
+
|
536 |
+
output = output.view({output.size(0), group, output.size(1) / group,
|
537 |
+
output.size(2), output.size(3)});
|
538 |
+
|
539 |
+
for (int b = 0; b < batch; b++) {
|
540 |
+
modulated_deformable_im2col_cuda(
|
541 |
+
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
|
542 |
+
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
543 |
+
dilation_h, dilation_w, deformable_group, columns);
|
544 |
+
|
545 |
+
// divide into group
|
546 |
+
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
547 |
+
weight.size(2), weight.size(3)});
|
548 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
549 |
+
|
550 |
+
for (int g = 0; g < group; g++) {
|
551 |
+
output[b][g] = output[b][g]
|
552 |
+
.flatten(1)
|
553 |
+
.addmm_(weight[g].flatten(1), columns[g])
|
554 |
+
.view_as(output[b][g]);
|
555 |
+
}
|
556 |
+
|
557 |
+
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
558 |
+
weight.size(3), weight.size(4)});
|
559 |
+
columns =
|
560 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
561 |
+
}
|
562 |
+
|
563 |
+
output = output.view({output.size(0), output.size(1) * output.size(2),
|
564 |
+
output.size(3), output.size(4)});
|
565 |
+
|
566 |
+
if (with_bias) {
|
567 |
+
output += bias.view({1, bias.size(0), 1, 1});
|
568 |
+
}
|
569 |
+
}
|
570 |
+
|
571 |
+
void modulated_deform_conv_cuda_backward(
|
572 |
+
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
|
573 |
+
at::Tensor offset, at::Tensor mask, at::Tensor columns,
|
574 |
+
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
|
575 |
+
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
|
576 |
+
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
|
577 |
+
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
|
578 |
+
const bool with_bias) {
|
579 |
+
TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
|
580 |
+
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
581 |
+
at::DeviceGuard guard(input.device());
|
582 |
+
|
583 |
+
const int batch = input.size(0);
|
584 |
+
const int channels = input.size(1);
|
585 |
+
const int height = input.size(2);
|
586 |
+
const int width = input.size(3);
|
587 |
+
|
588 |
+
const int channels_kernel = weight.size(1);
|
589 |
+
const int kernel_h_ = weight.size(2);
|
590 |
+
const int kernel_w_ = weight.size(3);
|
591 |
+
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
|
592 |
+
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
|
593 |
+
kernel_h_, kernel_w, kernel_h_, kernel_w_);
|
594 |
+
if (channels != channels_kernel * group)
|
595 |
+
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
|
596 |
+
channels, channels_kernel * group);
|
597 |
+
|
598 |
+
const int height_out =
|
599 |
+
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
|
600 |
+
const int width_out =
|
601 |
+
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
|
602 |
+
|
603 |
+
if (ones.ndimension() != 2 ||
|
604 |
+
ones.size(0) * ones.size(1) < height_out * width_out) {
|
605 |
+
// Resize plane and fill with ones...
|
606 |
+
ones = at::ones({height_out, width_out}, input.options());
|
607 |
+
}
|
608 |
+
|
609 |
+
grad_input = grad_input.view({batch, channels, height, width});
|
610 |
+
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
|
611 |
+
input.options());
|
612 |
+
|
613 |
+
grad_output =
|
614 |
+
grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
|
615 |
+
grad_output.size(2), grad_output.size(3)});
|
616 |
+
|
617 |
+
for (int b = 0; b < batch; b++) {
|
618 |
+
// divide int group
|
619 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
620 |
+
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
621 |
+
weight.size(2), weight.size(3)});
|
622 |
+
|
623 |
+
for (int g = 0; g < group; g++) {
|
624 |
+
columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
|
625 |
+
grad_output[b][g].flatten(1), 0.0f, 1.0f);
|
626 |
+
}
|
627 |
+
|
628 |
+
columns =
|
629 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
630 |
+
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
631 |
+
weight.size(3), weight.size(4)});
|
632 |
+
|
633 |
+
// gradient w.r.t. input coordinate data
|
634 |
+
modulated_deformable_col2im_coord_cuda(
|
635 |
+
columns, input[b], offset[b], mask[b], 1, channels, height, width,
|
636 |
+
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
|
637 |
+
stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
|
638 |
+
grad_mask[b]);
|
639 |
+
// gradient w.r.t. input data
|
640 |
+
modulated_deformable_col2im_cuda(
|
641 |
+
columns, offset[b], mask[b], 1, channels, height, width, height_out,
|
642 |
+
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
643 |
+
dilation_h, dilation_w, deformable_group, grad_input[b]);
|
644 |
+
|
645 |
+
// gradient w.r.t. weight, dWeight should accumulate across the batch and
|
646 |
+
// group
|
647 |
+
modulated_deformable_im2col_cuda(
|
648 |
+
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
|
649 |
+
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
650 |
+
dilation_h, dilation_w, deformable_group, columns);
|
651 |
+
|
652 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
653 |
+
grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
|
654 |
+
grad_weight.size(1), grad_weight.size(2),
|
655 |
+
grad_weight.size(3)});
|
656 |
+
if (with_bias)
|
657 |
+
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
|
658 |
+
|
659 |
+
for (int g = 0; g < group; g++) {
|
660 |
+
grad_weight[g] =
|
661 |
+
grad_weight[g]
|
662 |
+
.flatten(1)
|
663 |
+
.addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
|
664 |
+
.view_as(grad_weight[g]);
|
665 |
+
if (with_bias) {
|
666 |
+
grad_bias[g] =
|
667 |
+
grad_bias[g]
|
668 |
+
.view({-1, 1})
|
669 |
+
.addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
|
670 |
+
.view(-1);
|
671 |
+
}
|
672 |
+
}
|
673 |
+
|
674 |
+
columns =
|
675 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
676 |
+
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
|
677 |
+
grad_weight.size(2), grad_weight.size(3),
|
678 |
+
grad_weight.size(4)});
|
679 |
+
if (with_bias)
|
680 |
+
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
|
681 |
+
}
|
682 |
+
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
|
683 |
+
grad_output.size(2), grad_output.size(3),
|
684 |
+
grad_output.size(4)});
|
685 |
+
}
|