alexnasa committed (verified)
Commit 295978e · 1 Parent(s): 7094d74

Upload 54 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +11 -0
  2. LICENSE +201 -0
  3. README.md +13 -12
  4. app.py +435 -0
  5. assets/teaser.png +3 -0
  6. examples/amber.png +3 -0
  7. examples/armour.png +3 -0
  8. examples/art.wav +3 -0
  9. examples/chris.png +3 -0
  10. examples/dream.mp3 +3 -0
  11. examples/fictional.wav +3 -0
  12. examples/fight.wav +3 -0
  13. examples/jacket.png +3 -0
  14. examples/naomi.png +3 -0
  15. examples/science.wav +0 -0
  16. examples/vangogh.jpg +3 -0
  17. humo/common/__init__.py +0 -0
  18. humo/common/config.py +107 -0
  19. humo/common/distributed/__init__.py +41 -0
  20. humo/common/distributed/advanced.py +484 -0
  21. humo/common/distributed/basic.py +143 -0
  22. humo/common/logger.py +44 -0
  23. humo/configs/inference/generate.yaml +78 -0
  24. humo/configs/inference/generate_1_7B.yaml +76 -0
  25. humo/configs/models/Wan_1.3B.yaml +17 -0
  26. humo/configs/models/Wan_1.3B_I2V.yaml +18 -0
  27. humo/configs/models/Wan_14B.yaml +17 -0
  28. humo/configs/models/Wan_14B_I2V.yaml +18 -0
  29. humo/generate.py +984 -0
  30. humo/generate_1_7B.py +622 -0
  31. humo/models/audio/audio_proj.py +87 -0
  32. humo/models/distributed/__init__.py +0 -0
  33. humo/models/distributed/dit_ulysses_sequence_parallel.py +270 -0
  34. humo/models/distributed/fsdp.py +42 -0
  35. humo/models/text/encoder.py +173 -0
  36. humo/models/utils/fm_solvers.py +857 -0
  37. humo/models/utils/fm_solvers_unipc.py +800 -0
  38. humo/models/utils/utils.py +58 -0
  39. humo/models/wan_modules/__init__.py +16 -0
  40. humo/models/wan_modules/attention.py +256 -0
  41. humo/models/wan_modules/clip.py +542 -0
  42. humo/models/wan_modules/model.py +619 -0
  43. humo/models/wan_modules/model_humo.py +803 -0
  44. humo/models/wan_modules/t5.py +525 -0
  45. humo/models/wan_modules/tokenizers.py +82 -0
  46. humo/models/wan_modules/vae.py +666 -0
  47. humo/models/wan_modules/xlm_roberta.py +170 -0
  48. humo/utils/audio_processor_whisper.py +173 -0
  49. humo/utils/wav2vec.py +218 -0
  50. main.py +28 -0
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/teaser.png filter=lfs diff=lfs merge=lfs -text
37
+ examples/amber.png filter=lfs diff=lfs merge=lfs -text
38
+ examples/armour.png filter=lfs diff=lfs merge=lfs -text
39
+ examples/art.wav filter=lfs diff=lfs merge=lfs -text
40
+ examples/chris.png filter=lfs diff=lfs merge=lfs -text
41
+ examples/dream.mp3 filter=lfs diff=lfs merge=lfs -text
42
+ examples/fictional.wav filter=lfs diff=lfs merge=lfs -text
43
+ examples/fight.wav filter=lfs diff=lfs merge=lfs -text
44
+ examples/jacket.png filter=lfs diff=lfs merge=lfs -text
45
+ examples/naomi.png filter=lfs diff=lfs merge=lfs -text
46
+ examples/vangogh.jpg filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "{}"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2025 Bytedance
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,13 @@
1
- ---
2
- title: HuMo Local
3
- emoji: 🌍
4
- colorFrom: green
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 5.49.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
1
+ ---
2
+ title: HuMo [Local]
3
+ emoji: 👩‍🦱
4
+ colorFrom: purple
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 5.47.2
8
+ app_file: app.py
9
+ pinned: false
10
+ short_description: Reference based video generation
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,435 @@
1
+ import spaces
2
+ import gradio as gr
3
+ import sys
4
+ import os
5
+ import subprocess
6
+ import uuid
7
+ import shutil
8
+
9
+
10
+
11
+ from huggingface_hub import snapshot_download, list_repo_files, hf_hub_download
12
+ import importlib, site
13
+
14
+
15
+ # Re-discover all .pth/.egg-link files
16
+ for sitedir in site.getsitepackages():
17
+ site.addsitedir(sitedir)
18
+
19
+ # Clear caches so importlib will pick up new modules
20
+ importlib.invalidate_caches()
21
+
22
+ def sh(cmd): subprocess.check_call(cmd, shell=True)
23
+
24
+ flash_attention_installed = False
25
+
26
+ try:
27
+ flash_attention_wheel = hf_hub_download(
28
+ repo_id="alexnasa/flash-attn-3",
29
+ repo_type="model",
30
+ filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
31
+ )
32
+
33
+ sh(f"pip install {flash_attention_wheel}")
34
+ print("Attempting to download and install FlashAttention wheel...")
35
+ # sh("pip install flash-attn")
36
+ sh("pip install --no-build-isolation transformer_engine-2.5.0+f05f12c9-cp310-cp310-linux_x86_64.whl")
37
+
38
+ # tell Python to re-scan site-packages now that the egg-link exists
39
+ import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
40
+
41
+ flash_attention_installed = True
42
+
43
+ except Exception as e:
44
+ print(f"⚠️ Could not install FlashAttention: {e}")
45
+ print("Continuing without FlashAttention...")
46
+
47
+ try:
48
+ te_wheel = hf_hub_download(
49
+ repo_id="alexnasa/transformer_engine_wheels",
50
+ repo_type="model",
51
+ filename="transformer_engine-2.5.0+f05f12c9-cp310-cp310-linux_x86_64.whl",
52
+ )
53
+
54
+ sh(f"pip install {te_wheel}")
55
+ print("Attempting to download and install Transformer Engine wheel...")
56
+
57
+ # tell Python to re-scan site-packages now that the egg-link exists
58
+ import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
59
+
60
+ except Exception as e:
61
+ print(f"⚠️ Could not install Transformer Engine : {e}")
62
+ print("Continuing without Transformer Engine ...")
63
+
64
+ import torch
65
+ print(f"Torch version: {torch.__version__}")
66
+ print(f"FlashAttention available: {flash_attention_installed}")
67
+
68
+ import tempfile
69
+ from pathlib import Path
70
+ from torch._inductor.runtime.runtime_utils import cache_dir as _inductor_cache_dir
71
+ from huggingface_hub import HfApi
72
+
73
+
74
+ snapshot_download(repo_id="bytedance-research/HuMo", local_dir="./weights/HuMo")
75
+ snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./weights/Wan2.1-T2V-1.3B")
76
+ snapshot_download(repo_id="openai/whisper-large-v3", local_dir="./weights/whisper-large-v3")
77
+
78
+ os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/proprocess_results"
79
+
80
+ path_to_insert = "humo"
81
+ if path_to_insert not in sys.path:
82
+ sys.path.insert(0, path_to_insert)
83
+
84
+ from common.config import load_config, create_object
85
+
86
+ config = load_config(
87
+ "./humo/configs/inference/generate.yaml",
88
+ [
89
+ "dit.sp_size=1",
90
+ "generation.frames=97",
91
+ "generation.scale_t=5.5",
92
+ "generation.scale_a=5.0",
93
+ "generation.mode=TIA",
94
+ "generation.height=480",
95
+ "generation.width=832",
96
+ ],
97
+ )
98
+ runner = create_object(config)
99
+
100
+
101
+ os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", f"{os.getcwd()}/torchinductor_space") # or another writable path
102
+
103
+ def restore_inductor_cache_from_hub(repo_id: str, filename: str = "torch_compile_cache.zip",
104
+ path_in_repo: str = "inductor_cache", repo_type: str = "model",
105
+ hf_token: str | None = None):
106
+ cache_root = Path(_inductor_cache_dir()).resolve()
107
+ cache_root.mkdir(parents=True, exist_ok=True)
108
+ zip_path = hf_hub_download(repo_id=repo_id, filename=f"{path_in_repo}/{filename}",
109
+ repo_type=repo_type, token=hf_token)
110
+ shutil.unpack_archive(zip_path, extract_dir=str(cache_root))
111
+ print(f"✓ Restored cache into {cache_root}")
112
+
113
+
114
+ # restore_inductor_cache_from_hub("alexnasa/humo-compiled")
115
+
116
+
117
+ def get_duration(prompt_text, steps, image_file, audio_file_path, tea_cache_l1_thresh, max_duration, session_id):
118
+
119
+ return calculate_required_time(steps, max_duration)
120
+
121
+ def calculate_required_time(steps, max_duration):
122
+
123
+ warmup_s = 60
124
+
125
+ max_duration_duration_mapping = {
126
+ 1: 8,
127
+ 2: 8,
128
+ 3: 11,
129
+ 4: 20,
130
+ 5: 30,
131
+ }
132
+ each_step_s = max_duration_duration_mapping[max_duration]
133
+ duration_s = (each_step_s * steps) + warmup_s
134
+
135
+ print(f'estimated duration:{duration_s}')
136
+
137
+ return int(duration_s)
138
+
139
+ def get_required_time_string(steps, max_duration):
140
+
141
+ duration_s = calculate_required_time(steps, max_duration)
142
+ duration_m = duration_s / 60
143
+
144
+ return f"<center>⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)</center>"
145
+
146
+ def update_required_time(steps, max_duration):
147
+
148
+ return get_required_time_string(steps, max_duration)
149
+
150
+
151
+ def generate_scene(prompt_text, steps, image_paths, audio_file_path, tea_cache_l1_thresh, max_duration = 2, session_id = None):
152
+
153
+ print(image_paths)
154
+ prompt_text_check = (prompt_text or "").strip()
155
+ if not prompt_text_check:
156
+ raise gr.Error("Please enter a prompt.")
157
+
158
+ if not audio_file_path and not image_paths:
159
+ raise gr.Error("Please provide a reference image or a lipsync audio.")
160
+
161
+ return run_pipeline(prompt_text, steps, image_paths, audio_file_path, tea_cache_l1_thresh, max_duration, session_id)
162
+
163
+
164
+
165
+ def upload_inductor_cache_to_hub(
166
+ repo_id: str,
167
+ path_in_repo: str = "inductor_cache",
168
+ repo_type: str = "model", # or "dataset" if you prefer
169
+ hf_token: str | None = None,
170
+ ):
171
+ """
172
+ Zips the current TorchInductor cache and uploads it to the given repo path.
173
+ Assumes the model was already run once with torch.compile() so the cache exists.
174
+ """
175
+
176
+ cache_dir = Path(_inductor_cache_dir()).resolve()
177
+ if not cache_dir.exists():
178
+ raise FileNotFoundError(f"TorchInductor cache not found at {cache_dir}. "
179
+ "Run a compiled model once to populate it.")
180
+
181
+ # Create a zip archive of the entire cache directory
182
+ with tempfile.TemporaryDirectory() as tmpdir:
183
+ archive_base = Path(tmpdir) / "torch_compile_cache"
184
+ archive_path = shutil.make_archive(str(archive_base), "zip", root_dir=str(cache_dir))
185
+ archive_path = Path(archive_path)
186
+
187
+ # Upload to Hub
188
+ api = HfApi(token=hf_token)
189
+ api.create_repo(repo_id=repo_id, repo_type=repo_type, exist_ok=True)
190
+ # Put each artifact under path_in_repo, including a tiny metadata stamp for traceability
191
+ # Upload the zip
192
+ dest_path = f"{path_in_repo}/{archive_path.name}"
193
+ api.upload_file(
194
+ path_or_fileobj=str(archive_path),
195
+ path_in_repo=dest_path,
196
+ repo_id=repo_id,
197
+ repo_type=repo_type,
198
+ )
199
+ # Upload a small metadata file (optional but handy)
200
+ meta_txt = (
201
+ f"pytorch={torch.__version__}\n"
202
+ f"inductor_cache_dir={cache_dir}\n"
203
+ f"cuda_available={torch.cuda.is_available()}\n"
204
+ f"cuda_device={torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu'}\n"
205
+ )
206
+ api.upload_file(
207
+ path_or_fileobj=meta_txt.encode(),
208
+ path_in_repo=f"{path_in_repo}/INDUCTOR_CACHE_METADATA.txt",
209
+ repo_id=repo_id,
210
+ repo_type=repo_type,
211
+ )
212
+
213
+ print("✔ Uploaded TorchInductor cache to the Hub.")
214
+
215
+
216
+ @spaces.GPU(duration=get_duration)
217
+ def run_pipeline(prompt_text, steps, image_paths, audio_file_path, tea_cache_l1_thresh = 0.0, max_duration = 2, session_id = None):
218
+
219
+ if session_id is None:
220
+ session_id = uuid.uuid4().hex
221
+
222
+ inference_mode = "TIA"
223
+
224
+ # Validate inputs
225
+ prompt_text = (prompt_text or "").strip()
226
+ if not prompt_text:
227
+ raise gr.Error("Please enter a prompt.")
228
+
229
+ if not audio_file_path and not image_paths:
230
+ raise gr.Error("Please provide a reference image or a lipsync audio.")
231
+
232
+ if not audio_file_path:
233
+ inference_mode = "TI"
234
+ audio_path = None
235
+ else:
236
+ audio_path = audio_file_path if isinstance(audio_file_path, str) else getattr(audio_file_path, "name", str(audio_file_path))
237
+
238
+ if not image_paths:
239
+ inference_mode = "TA"
240
+ img_paths = None
241
+ else:
242
+ img_paths = [image_data[0] for image_data in image_paths]
243
+
244
+
245
+ # Prepare output
246
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
247
+ os.makedirs(output_dir, exist_ok=True)
248
+
249
+ # Random filename
250
+ filename = f"gen_{uuid.uuid4().hex[:10]}"
251
+ width, height = 832, 480
252
+
253
+ duration_frame_mapping = {
254
+ 1:25,
255
+ 2:45,
256
+ 3:70,
257
+ 4:97,
258
+ 5:129
259
+ }
260
+
261
+ # Run inference
262
+ runner.inference_loop(
263
+ prompt_text,
264
+ img_paths,
265
+ audio_path,
266
+ output_dir,
267
+ filename,
268
+ inference_mode,
269
+ width,
270
+ height,
271
+ steps,
272
+ frames = int(duration_frame_mapping[max_duration]),
273
+ tea_cache_l1_thresh = tea_cache_l1_thresh,
274
+ )
275
+
276
+ # Return resulting video path
277
+ video_path = os.path.join(output_dir, f"{filename}.mp4")
278
+ if os.path.exists(video_path):
279
+
280
+ # upload_inductor_cache_to_hub("alexnasa/humo-compiled")
281
+
282
+ return video_path
283
+ else:
284
+ candidates = [os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.endswith(".mp4")]
285
+ if candidates:
286
+ return max(candidates, key=lambda p: os.path.getmtime(p))
287
+ return None
288
+
289
+ css = """
290
+ #col-container {
291
+ margin: 0 auto;
292
+ width: 100%;
293
+ max-width: 720px;
294
+ }
295
+ """
296
+
297
+ def cleanup(request: gr.Request):
298
+
299
+ sid = request.session_hash
300
+ if sid:
301
+ d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
302
+ shutil.rmtree(d1, ignore_errors=True)
303
+
304
+ def start_session(request: gr.Request):
305
+
306
+ return request.session_hash
307
+
308
+ with gr.Blocks(css=css) as demo:
309
+
310
+ session_state = gr.State()
311
+ demo.load(start_session, outputs=[session_state])
312
+
313
+ with gr.Sidebar(width=400):
314
+
315
+
316
+ gr.HTML(
317
+ """
318
+ <div style="text-align: center;">
319
+ <p style="font-size:16px; display: inline; margin: 0;">
320
+ <strong>HuMo</strong> – Human-Centric Video Generation via Collaborative Multi-Modal Conditioning
321
+ </p>
322
+ <a href="https://github.com/Phantom-video/HuMo" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
323
+ [Github]
324
+ </a>
325
+ </div>
326
+ """
327
+ )
328
+
329
+ gr.Markdown("**REFERENCE IMAGES**")
330
+
331
+ img_input = gr.Gallery(
332
+ show_label=False,
333
+ label="",
334
+ interactive=True,
335
+ rows=1, columns=3, object_fit="contain", height="280",
336
+ file_types=['image']
337
+ )
338
+
339
+ gr.Markdown("**LIPSYNC AUDIO**")
340
+
341
+ audio_input = gr.Audio(
342
+ sources=["upload"],
343
+ show_label=False,
344
+ type="filepath",
345
+ )
346
+
347
+ gr.Markdown("**SETTINGS**")
348
+
349
+ default_steps = 10
350
+ default_max_duration = 2
351
+
352
+ max_duration = gr.Slider(minimum=2, maximum=5, value=default_max_duration, step=1, label="Max Duration")
353
+ steps_input = gr.Slider(minimum=5, maximum=50, value=default_steps, step=5, label="Diffusion Steps")
354
+ tea_cache_l1_thresh = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.01, label="Cache", visible=False)
355
+
356
+
357
+
358
+ with gr.Column(elem_id="col-container"):
359
+
360
+ gr.HTML(
361
+ """
362
+ <div style="text-align: center;">
363
+ <strong>HF Space by:</strong>
364
+ <a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
365
+ <img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="GitHub Repo">
366
+ </a>
367
+ </div>
368
+ """
369
+ )
370
+
371
+ video_output = gr.Video(show_label=False)
372
+
373
+ gr.Markdown("<center><h2>PROMPT</h2></center>")
374
+
375
+ prompt_tb = gr.Textbox(
376
+ show_label=False,
377
+ lines=5,
378
+ placeholder="Describe the scene and the person talking....",
379
+ )
380
+
381
+ gr.Markdown("")
382
+ time_required = gr.Markdown(get_required_time_string(default_steps, default_max_duration))
383
+ run_btn = gr.Button("🎬 Action", variant="primary")
384
+
385
+ gr.Examples(
386
+ examples=[
387
+
388
+ [
389
+ "A handheld tracking shot follows a female warrior walking through a cave. Her determined eyes are locked straight ahead. She speaks with intensity.",
390
+ 5,
391
+ ["./examples/naomi.png"],
392
+ "./examples/dream.mp3",
393
+ ],
394
+
395
+ [
396
+ "A reddish-brown haired and bearded man sits pensively against swirling blue-and-white brushstrokes, dressed in a blue coat and dark waistcoat. The artistic backdrop and his thoughtful pose evoke a Post-Impressionist style in a studio-like setting.",
397
+ 10,
398
+ ["./examples/vangogh.jpg"],
399
+ "./examples/art.wav",
400
+ ],
401
+
402
+ [
403
+ "A handheld tracking shot follows a female through a science lab. Her determined eyes are locked straight ahead. The clip is in black and white and patchy as she is explaining something to someone standing opposite her",
404
+ 10,
405
+ ["./examples/naomi.png"],
406
+ "./examples/science.wav",
407
+ ],
408
+
409
+ [
410
+ "A woman with long, wavy dark hair looking at a person sitting opposite her whilst holding a book, wearing a leather jacket, long-sleeved jacket with a semi purple color one seen on a photo. Warm, window-like light bathes her figure, highlighting the outfit's elegant design and her graceful movements.",
411
+ 50,
412
+ ["./examples/amber.png", "./examples/jacket.png"],
413
+ "./examples/fictional.mp3",
414
+ ],
415
+
416
+ ],
417
+ inputs=[prompt_tb, steps_input, img_input, audio_input],
418
+ outputs=[video_output],
419
+ fn=run_pipeline,
420
+ cache_examples=True,
421
+ )
422
+ max_duration.change(update_required_time, [steps_input, max_duration], time_required)
423
+ steps_input.change(update_required_time, [steps_input, max_duration], time_required)
424
+
425
+ run_btn.click(
426
+ fn=generate_scene,
427
+ inputs=[prompt_tb, steps_input, img_input, audio_input, tea_cache_l1_thresh, max_duration, session_state],
428
+ outputs=[video_output],
429
+ )
430
+
431
+
432
+ if __name__ == "__main__":
433
+ demo.unload(cleanup)
434
+ demo.queue()
435
+ demo.launch(ssr_mode=False)
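
A quick check of the ZeroGPU budget computed by `get_duration` above: the sketch below mirrors `calculate_required_time` from this `app.py`, with the per-step seconds taken from its hard-coded mapping (rough estimates, not measured values).

```python
# Minimal standalone sketch of the duration estimate used by @spaces.GPU(duration=get_duration).
WARMUP_S = 60
STEP_SECONDS = {1: 8, 2: 8, 3: 11, 4: 20, 5: 30}  # max_duration slider value -> estimated seconds per diffusion step

def estimated_gpu_seconds(steps: int, max_duration: int) -> int:
    # total = per-step cost * number of diffusion steps + fixed warmup
    return STEP_SECONDS[max_duration] * steps + WARMUP_S

# Default UI settings (steps=10, max_duration=2): 8 * 10 + 60 = 140 s (~2.3 min),
# the figure rendered by get_required_time_string().
print(estimated_gpu_seconds(10, 2))  # 140
```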
assets/teaser.png ADDED

Git LFS Details

  • SHA256: 722d29d27fb89a6e1ebebef233492f0c06c25b09a0bdb8e723ef567e778bcf34
  • Pointer size: 132 Bytes
  • Size of remote file: 5.83 MB
examples/amber.png ADDED

Git LFS Details

  • SHA256: 6ce1a891ea71b184eeb4bc768322006c4ecccc8e063b2a0afd38829c6e975f03
  • Pointer size: 132 Bytes
  • Size of remote file: 2.4 MB
examples/armour.png ADDED

Git LFS Details

  • SHA256: 192dd4b1c80c9ddacb8678962b5a1c04855d44c9877810aba032761ce50052a2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.79 MB
examples/art.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:72c75df8e93a107e262ea9b002a66e72d3c1cd2084bce1474a31d8afffd0b651
3
+ size 114254
examples/chris.png ADDED

Git LFS Details

  • SHA256: a3100088e3247d8ecaf1ded2e2417e70d6ad34d24ce7f6e7551cb7fd24c91dcf
  • Pointer size: 132 Bytes
  • Size of remote file: 2.05 MB
examples/dream.mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:27248fd9e8f29bd60ccb1163b8df3c6f2630734f358aa3362ffe67e8148e0eb1
3
+ size 108275
examples/fictional.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31b550e6433ea44a0642dee90c326664ff4f568fec184170001f834597b3ad23
3
+ size 167084
examples/fight.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8dbee86c85e992ac6d17820a3730bf753fc9bf5bac6b8a470f84b7e98a64221a
3
+ size 264782
examples/jacket.png ADDED

Git LFS Details

  • SHA256: e80a02659148e3364eaa46e46dd64a268f04c7f7eeed0e8f203b6b848738666a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.57 MB
examples/naomi.png ADDED

Git LFS Details

  • SHA256: 5666cd6253658e76695e8529b28d730c74f9e63a1afca07e47505e59b24e7656
  • Pointer size: 132 Bytes
  • Size of remote file: 1.18 MB
examples/science.wav ADDED
Binary file (82.5 kB).
 
examples/vangogh.jpg ADDED

Git LFS Details

  • SHA256: 1ae77da89271f32196ad9e8a915e20a7f71e9a84b78ecec3aa1dfcb4e4b39de1
  • Pointer size: 131 Bytes
  • Size of remote file: 139 kB
humo/common/__init__.py ADDED
File without changes
humo/common/config.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ # Codes adapted from [SeedVR]
13
+ # https://github.com/ByteDance-Seed/SeedVR/blob/main/common/config.py
14
+
15
+ """
16
+ Configuration utility functions
17
+ """
18
+
19
+ import importlib
20
+ from typing import Any, Callable, List, Union
21
+ from omegaconf import DictConfig, ListConfig, OmegaConf
22
+
23
+ OmegaConf.register_new_resolver("eval", eval)
24
+
25
+
26
+ def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]:
27
+ """
28
+ Load a configuration. Will resolve inheritance.
29
+ """
30
+ config = OmegaConf.load(path)
31
+ if argv is not None:
32
+ config_argv = OmegaConf.from_dotlist(argv)
33
+ config = OmegaConf.merge(config, config_argv)
34
+ config = resolve_recursive(config, resolve_inheritance)
35
+ return config
36
+
37
+
38
+ def resolve_recursive(
39
+ config: Any,
40
+ resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]],
41
+ ) -> Any:
42
+ config = resolver(config)
43
+ if isinstance(config, DictConfig):
44
+ for k in config.keys():
45
+ v = config.get(k)
46
+ if isinstance(v, (DictConfig, ListConfig)):
47
+ config[k] = resolve_recursive(v, resolver)
48
+ if isinstance(config, ListConfig):
49
+ for i in range(len(config)):
50
+ v = config.get(i)
51
+ if isinstance(v, (DictConfig, ListConfig)):
52
+ config[i] = resolve_recursive(v, resolver)
53
+ return config
54
+
55
+
56
+ def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any:
57
+ """
58
+ Recursively resolve inheritance if the config contains:
59
+ __inherit__: path/to/parent.yaml.
60
+ """
61
+ if isinstance(config, DictConfig):
62
+ inherit = config.pop("__inherit__", None)
63
+ if inherit:
64
+ assert isinstance(inherit, str)
65
+ inherit = load_config(inherit)
66
+ if len(config.keys()) > 0:
67
+ config = OmegaConf.merge(inherit, config)
68
+ else:
69
+ config = inherit
70
+ return config
71
+
72
+
73
+ def import_item(path: str, name: str) -> Any:
74
+ """
75
+ Import a python item. Example: import_item("path.to.file", "MyClass") -> MyClass
76
+ """
77
+ return getattr(importlib.import_module(path), name)
78
+
79
+
80
+ def create_object(config: DictConfig) -> Any:
81
+ """
82
+ Create an object from config.
83
+ The config is expected to contains the following:
84
+ __object__:
85
+ path: path.to.module
86
+ name: MyClass
87
+ args: as_config | as_params (default to as_config)
88
+ """
89
+ item = import_item(
90
+ path=config.__object__.path,
91
+ name=config.__object__.name,
92
+ )
93
+ args = config.__object__.get("args", "as_config")
94
+ if args == "as_config":
95
+ return item(config)
96
+ if args == "as_params":
97
+ config = OmegaConf.to_object(config)
98
+ config.pop("__object__")
99
+ return item(**config)
100
+ raise NotImplementedError(f"Unknown args type: {args}")
101
+
102
+
103
+ def create_dataset(path: str, *args, **kwargs) -> Any:
104
+ """
105
+ Create a dataset. Requires the file to contain a "create_dataset" function.
106
+ """
107
+ return import_item(path, "create_dataset")(*args, **kwargs)
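
For reference, a minimal sketch of how the `__object__` convention documented above is consumed by `create_object`; it assumes the `humo/` directory is on `sys.path` (as `app.py` arranges) and uses `collections.Counter` purely as a stand-in target class for illustration.

```python
from omegaconf import OmegaConf
from common.config import create_object  # humo/common/config.py from this commit

# With args="as_params", every key except __object__ is passed to the imported
# class as a keyword argument; "as_config" would pass the whole DictConfig instead.
cfg = OmegaConf.create({
    "__object__": {"path": "collections", "name": "Counter", "args": "as_params"},
    "a": 1,
    "b": 2,
})
obj = create_object(cfg)   # equivalent to collections.Counter(a=1, b=2)
print(type(obj).__name__)  # Counter
```

`app.py` builds its `runner` the same way, by loading `humo/configs/inference/generate.yaml` and passing the merged config to `create_object`.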
humo/common/distributed/__init__.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ # Codes adapted from [SeedVR]
13
+ # https://github.com/ByteDance-Seed/SeedVR/tree/main/common/distributed
14
+
15
+ """
16
+ Distributed package.
17
+ """
18
+
19
+ from .basic import (
20
+ barrier_if_distributed,
21
+ convert_to_ddp,
22
+ get_device,
23
+ get_global_rank,
24
+ get_local_rank,
25
+ get_world_size,
26
+ init_torch,
27
+ meta_param_init_fn,
28
+ meta_non_persistent_buffer_init_fn
29
+ )
30
+
31
+ __all__ = [
32
+ "barrier_if_distributed",
33
+ "convert_to_ddp",
34
+ "get_device",
35
+ "get_global_rank",
36
+ "get_local_rank",
37
+ "get_world_size",
38
+ "init_torch",
39
+ "meta_param_init_fn",
40
+ "meta_non_persistent_buffer_init_fn",
41
+ ]
humo/common/distributed/advanced.py ADDED
@@ -0,0 +1,484 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ # Codes adapted from [SeedVR]
13
+ # https://github.com/ByteDance-Seed/SeedVR/tree/main/common/distributed
14
+
15
+ """
16
+ Advanced distributed functions for sequence parallel.
17
+ """
18
+
19
+ import torch
20
+ from typing import Any, List, Optional, Tuple, Union
21
+ import torch.distributed as dist
22
+ from torch import Tensor
23
+
24
+ from .basic import get_global_rank, get_world_size
25
+
26
+
27
+ _DATA_PARALLEL_GROUP = None
28
+ _SEQUENCE_PARALLEL_GROUP = None
29
+ _SEQUENCE_PARALLEL_CPU_GROUP = None
30
+
31
+
32
+ _CFG_PARALLEL_GROUP = None
33
+ _CFG_PARALLEL_CPU_GROUP = None
34
+
35
+ def get_data_parallel_group() -> Optional[dist.ProcessGroup]:
36
+ """
37
+ Get data parallel process group.
38
+ """
39
+ return _DATA_PARALLEL_GROUP
40
+
41
+
42
+ def get_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
43
+ """
44
+ Get sequence parallel process group.
45
+ """
46
+ return _SEQUENCE_PARALLEL_GROUP
47
+
48
+
49
+ def get_sequence_parallel_cpu_group() -> Optional[dist.ProcessGroup]:
50
+ """
51
+ Get sequence parallel CPU process group.
52
+ """
53
+ return _SEQUENCE_PARALLEL_CPU_GROUP
54
+
55
+
56
+ def get_data_parallel_rank() -> int:
57
+ """
58
+ Get data parallel rank.
59
+ """
60
+ group = get_data_parallel_group()
61
+ return dist.get_rank(group) if group else get_global_rank()
62
+
63
+
64
+ def get_data_parallel_world_size() -> int:
65
+ """
66
+ Get data parallel world size.
67
+ """
68
+ group = get_data_parallel_group()
69
+ return dist.get_world_size(group) if group else get_world_size()
70
+
71
+
72
+ def get_sequence_parallel_rank() -> int:
73
+ """
74
+ Get sequence parallel rank.
75
+ """
76
+ group = get_sequence_parallel_group()
77
+ return dist.get_rank(group) if group else 0
78
+
79
+
80
+ def get_sequence_parallel_world_size() -> int:
81
+ """
82
+ Get sequence parallel world size.
83
+ """
84
+ group = get_sequence_parallel_group()
85
+ return dist.get_world_size(group) if group else 1
86
+
87
+
88
+ def init_unified_parallel(unified_parallel_size):
89
+ global _SEQUENCE_PARALLEL_GROUP
90
+ global _SEQUENCE_PARALLEL_CPU_GROUP
91
+
92
+ if unified_parallel_size == 1:
93
+ return
94
+
95
+ assert dist.is_initialized()
96
+ world_size = dist.get_world_size()
97
+ rank = dist.get_rank()
98
+ assert world_size % unified_parallel_size == 0
99
+ data_parallel_size = world_size // unified_parallel_size
100
+
101
+ for i in range(data_parallel_size):
102
+ # build unified parallel group
103
+ start_rank = i * unified_parallel_size
104
+ end_rank = start_rank + unified_parallel_size
105
+ unified_parallel_ranks = range(start_rank, end_rank)
106
+ unified_parallel_group = dist.new_group(unified_parallel_ranks)
107
+ unified_parallel_cpu_group = dist.new_group(unified_parallel_ranks, backend="gloo")
108
+ if rank in unified_parallel_ranks:
109
+ _SEQUENCE_PARALLEL_GROUP = unified_parallel_group
110
+ _SEQUENCE_PARALLEL_CPU_GROUP = unified_parallel_cpu_group
111
+
112
+
113
+ def get_unified_parallel_group():
114
+ global _SEQUENCE_PARALLEL_GROUP
115
+ return _SEQUENCE_PARALLEL_GROUP
116
+
117
+
118
+ def get_unified_parallel_cpu_group():
119
+ global _SEQUENCE_PARALLEL_CPU_GROUP
120
+ return _SEQUENCE_PARALLEL_CPU_GROUP
121
+
122
+
123
+ def get_unified_parallel_rank():
124
+ group = get_unified_parallel_group()
125
+ return dist.get_rank(group) if group else 0
126
+
127
+
128
+ def get_unified_parallel_world_size():
129
+ group = get_unified_parallel_group()
130
+ return dist.get_world_size(group) if group else 1
131
+
132
+
133
+ def is_unified_parallel_initialized():
134
+ group = get_unified_parallel_group()
135
+ return group is not None
136
+
137
+
138
+ def pad_tensor(x: Tensor, dim: int, padding_size: int):
139
+ shape = list(x.shape)
140
+ shape[dim] = padding_size
141
+ pad = torch.zeros(shape, dtype=x.dtype, device=x.device)
142
+ return torch.cat([x, pad], dim=dim)
143
+
144
+
145
+ class Slice(torch.autograd.Function):
146
+ @staticmethod
147
+ def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, dim: int, scale_grad: bool) -> Tensor:
148
+ ctx.group = group
149
+ ctx.rank = dist.get_rank(group)
150
+ seq_world_size = dist.get_world_size(group)
151
+ ctx.seq_world_size = seq_world_size
152
+ ctx.dim = dim
153
+ ctx.scale_grad = scale_grad
154
+ dim_size = local_input.shape[dim]
155
+ if not ctx.group:
156
+ return local_input
157
+ return local_input.split(dim_size // seq_world_size, dim=dim)[ctx.rank].contiguous()
158
+
159
+ @staticmethod
160
+ def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None]:
161
+ if not ctx.group:
162
+ return None, grad_output, None, None
163
+ dim_size = list(grad_output.size())
164
+ split_size = dim_size[0]
165
+ dim_size[0] = dim_size[0] * ctx.seq_world_size
166
+ output = torch.empty(dim_size, dtype=grad_output.dtype, device=torch.cuda.current_device())
167
+ dist.all_gather_into_tensor(output, grad_output, group=ctx.group)
168
+ if ctx.scale_grad:
169
+ output = output / ctx.seq_world_size
170
+ return (None, torch.cat(output.split(split_size), dim=ctx.dim), None, None)
171
+
172
+
173
+ def gather_outputs(
174
+ x: Tensor,
175
+ gather_dim: int,
176
+ padding_dim: Optional[int] = None,
177
+ unpad_dim_size: Optional[int] = None,
178
+ scale_grad=True,
179
+ ):
180
+ """
181
+ A func to gather the outputs for the model result in sequence parallel
182
+ """
183
+ group = get_unified_parallel_group()
184
+ if not group:
185
+ return x
186
+ x = Gather.apply(group, x, gather_dim, scale_grad)
187
+ if padding_dim is not None:
188
+ x = unpadding_tensor_for_seqeunce_parallel(x, padding_dim, unpad_dim_size)
189
+ return x
190
+
191
+
192
+ def unpadding_tensor_for_seqeunce_parallel(x: Tensor, dim: int, unpadded_dim_size: int):
193
+ """
194
+ A func to remove the padding part of the tensor based on its original shape
195
+ """
196
+ group = get_unified_parallel_group()
197
+ if group is None:
198
+ return x
199
+ sp_world = get_unified_parallel_world_size()
200
+ if unpadded_dim_size % sp_world == 0:
201
+ return x
202
+ padding_size = sp_world - (unpadded_dim_size % sp_world)
203
+ assert (padding_size + unpadded_dim_size) % sp_world == 0
204
+ return unpad_tensor(x, dim=dim, padding_size=padding_size)
205
+
206
+
207
+ def gather_seq_scatter_heads_qkv(
208
+ qkv_tensor: Tensor,
209
+ seq_dim: int,
210
+ unpadded_dim_size: Optional[int] = None,
211
+ restore_shape: bool = True,
212
+ async_op: bool = False,
213
+ ):
214
+ """
215
+ A func to sync a split qkv tensor
216
+ qkv_tensor: the tensor we want to do alltoall with. The last dim must
217
+ be the projection_idx, which we will split into 3 parts. After
218
+ splitting, the gather idx will be projection_idx + 1
219
+ seq_dim: gather_dim for all2all comm
220
+ restore_shape: if True, the output will have the same shape length as the input
221
+ """
222
+ group = get_unified_parallel_group()
223
+ if not group:
224
+ return qkv_tensor
225
+ world = get_unified_parallel_world_size()
226
+ orig_shape = qkv_tensor.shape
227
+ scatter_dim = qkv_tensor.dim()
228
+ bef_all2all_shape = list(orig_shape)
229
+ qkv_proj_dim = bef_all2all_shape[-1]
230
+ bef_all2all_shape = bef_all2all_shape[:-1] + [3, qkv_proj_dim // 3]
231
+ qkv_tensor = qkv_tensor.view(bef_all2all_shape)
232
+ if async_op:
233
+ return SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op)
234
+ else:
235
+ qkv_tensor = SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op)
236
+
237
+ if restore_shape:
238
+ out_shape = list(orig_shape)
239
+ out_shape[seq_dim] *= world
240
+ out_shape[-1] = qkv_proj_dim // world
241
+ qkv_tensor = qkv_tensor.view(out_shape)
242
+
243
+ # remove padding
244
+ if unpadded_dim_size and unpadded_dim_size % world != 0:
245
+ padding_size = qkv_tensor.size(seq_dim) - unpadded_dim_size
246
+ qkv_tensor = unpad_tensor(qkv_tensor, seq_dim, padding_size)
247
+
248
+ return qkv_tensor
249
+
250
+
251
+ def gather_seq_scatter_double_head(
252
+ qkv_tensor: Tensor,
253
+ seq_dim: int,
254
+ unpadded_dim_size: Optional[int] = None,
255
+ restore_shape: bool = True,
256
+ async_op: bool = False,
257
+ ):
258
+ """
259
+ A func to sync a split qkv tensor
260
+ qkv_tensor: the tensor we want to do alltoall with. The last dim must
261
+ be the projection_idx, which we will split into 2 parts. After
262
+ splitting, the gather idx will be projection_idx + 1
263
+ seq_dim: gather_dim for all2all comm
264
+ restore_shape: if True, the output will have the same shape length as the input
265
+ """
266
+ qkv1_shape = qkv_tensor.shape
267
+ group = get_unified_parallel_group()
268
+ if not group:
269
+ return qkv_tensor
270
+ world = get_unified_parallel_world_size()
271
+ orig_shape = qkv_tensor.shape
272
+ scatter_dim = qkv_tensor.dim()
273
+ bef_all2all_shape = list(orig_shape)
274
+ qkv_proj_dim = bef_all2all_shape[-1]
275
+ bef_all2all_shape = bef_all2all_shape[:-1] + [2, qkv_proj_dim // 2]
276
+ qkv_tensor = qkv_tensor.view(bef_all2all_shape)
277
+ qkv2_shape = qkv_tensor.shape
278
+ if async_op:
279
+ return SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op)
280
+ else:
281
+ qkv_tensor = SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op)
282
+ qkv3_shape = qkv_tensor.shape
283
+
284
+ if restore_shape:
285
+ out_shape = list(orig_shape)
286
+ out_shape[seq_dim] *= world
287
+ out_shape[-1] = qkv_proj_dim // world
288
+ qkv_tensor = qkv_tensor.view(out_shape)
289
+ qkv4_shape = qkv_tensor.shape
290
+
291
+ # remove padding
292
+ if unpadded_dim_size and unpadded_dim_size % world != 0:
293
+ padding_size = qkv_tensor.size(seq_dim) - unpadded_dim_size
294
+ qkv_tensor = unpad_tensor(qkv_tensor, seq_dim, padding_size)
295
+ qkv5_shape = qkv_tensor.shape
296
+
297
+ return qkv_tensor
298
+
299
+
300
+ class SeqAllToAll(torch.autograd.Function):
301
+ @staticmethod
302
+ def forward(
303
+ ctx: Any,
304
+ group: dist.ProcessGroup,
305
+ local_input: Tensor,
306
+ scatter_dim: int,
307
+ gather_dim: int,
308
+ async_op: bool,
309
+ ) -> Tensor:
310
+ ctx.group = group
311
+ ctx.scatter_dim = scatter_dim
312
+ ctx.gather_dim = gather_dim
313
+ ctx.async_op = async_op
314
+ return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op)
315
+
316
+ @staticmethod
317
+ def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
318
+ if ctx.async_op:
319
+ input_t = torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous()
320
+ else:
321
+ input_t = grad_output[0]
322
+ return (
323
+ None,
324
+ all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False),
325
+ None,
326
+ None,
327
+ None,
328
+ None,
329
+ )
330
+
331
+
332
+ def all_to_all_tensor(
333
+ x: Tensor,
334
+ scatter_dim: int,
335
+ gather_dim: int,
336
+ group: dist.ProcessGroup,
337
+ async_op: bool = False,
338
+ ):
339
+ if scatter_dim <= 1 and gather_dim <= 1:
340
+ return _all_to_all_single(x, scatter_dim, gather_dim, group, async_op)
341
+ else:
342
+ return _all_to_all(x, scatter_dim, gather_dim, group, async_op) # this path is taken
343
+
344
+
345
+ def _all_to_all(
346
+ local_input: Tensor,
347
+ scatter_dim: int,
348
+ gather_dim: int,
349
+ group: dist.ProcessGroup,
350
+ async_op: bool = False,
351
+ ):
352
+ seq_world_size = dist.get_world_size(group)
353
+ input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)]
354
+ output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
355
+ comm = dist.all_to_all(output_list, input_list, group=group, async_op=async_op)
356
+ if async_op:
357
+
358
+ def wait():
359
+ comm.wait()
360
+ return torch.cat(output_list, dim=gather_dim).contiguous()
361
+
362
+ return wait
363
+ return torch.cat(output_list, dim=gather_dim).contiguous()
364
+
365
+
366
+ def _all_to_all_single(x: Tensor, scatter_dim: int, gather_dim: int, group: dist.ProcessGroup, async_op: bool = False):
367
+ """
368
+ A function to do all-to-all on the first two dim
369
+ """
370
+ sp_world_size = dist.get_world_size(group)
371
+ assert scatter_dim <= 1, "scatter_dim must be 0 or 1 when using all_to_all_single!"
372
+ assert gather_dim <= 1, "gather_dim must be 0 or 1 when using all_to_all_single!"
373
+ if scatter_dim != 0:
374
+ gather_dim_bef = x.shape[gather_dim]
375
+ scatter_dim_bef = x.shape[scatter_dim]
376
+ x = (
377
+ x.reshape([gather_dim_bef, sp_world_size, scatter_dim_bef // sp_world_size] + list(x.shape[2:]))
378
+ .transpose(0, 1)
379
+ .reshape([gather_dim_bef * sp_world_size, scatter_dim_bef // sp_world_size] + list(x.shape[2:]))
380
+ .contiguous()
381
+ )
382
+
383
+ output = torch.empty_like(x)
384
+ comm = dist.all_to_all_single(output, x.contiguous(), group=group, async_op=async_op)
385
+
386
+ if async_op:
387
+
388
+ def wait():
389
+ comm.wait()
390
+ if scatter_dim == 0:
391
+ return torch.cat(output.split(x.size(0) // sp_world_size), dim=gather_dim)
392
+ else:
393
+ return output
394
+
395
+ return wait
396
+
397
+ if scatter_dim == 0:
398
+ output = torch.cat(output.split(x.size(0) // sp_world_size), dim=gather_dim)
399
+ return output
400
+
401
+
402
+ def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int) -> Tensor:
403
+ """
404
+ A func to sync attention result with alltoall in sequence parallel
405
+ """
406
+ group = get_unified_parallel_group()
407
+ if not group:
408
+ return x
409
+ dim_size = x.size(seq_dim)
410
+ sp_world = get_unified_parallel_world_size()
411
+ if dim_size % sp_world != 0:
412
+ padding_size = sp_world - (dim_size % sp_world)
413
+ x = pad_tensor(x, seq_dim, padding_size)
414
+ return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)
415
+
416
+
417
+ def unpad_tensor(x: Tensor, dim: int, padding_size: int):
418
+ slc = [slice(None)] * len(x.shape)
419
+ slc[dim] = slice(0, -padding_size)
420
+ return x[slc]
421
+
422
+
423
+ class Gather(torch.autograd.Function):
424
+ @staticmethod
425
+ def forward(
426
+ ctx: Any,
427
+ group: dist.ProcessGroup,
428
+ local_input: Tensor,
429
+ dim: int,
430
+ grad_scale: Optional[bool] = False,
431
+ ) -> Tensor:
432
+ ctx.group = group
433
+ ctx.rank = dist.get_rank(group)
434
+ ctx.dim = dim
435
+ ctx.grad_scale = grad_scale
436
+ seq_world_size = dist.get_world_size(group)
437
+ ctx.seq_world_size = seq_world_size
438
+ dim_size = list(local_input.size())
439
+ split_size = dim_size[0]
440
+ ctx.part_size = dim_size[dim]
441
+ dim_size[0] = dim_size[0] * seq_world_size
442
+ output = torch.empty(dim_size, dtype=local_input.dtype, device=torch.cuda.current_device())
443
+ dist.all_gather_into_tensor(output, local_input.contiguous(), group=ctx.group)
444
+ return torch.cat(output.split(split_size), dim=dim)
445
+
446
+ @staticmethod
447
+ def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor]:
448
+ if ctx.grad_scale:
449
+ grad_output = grad_output * ctx.seq_world_size
450
+ return (
451
+ None,
452
+ grad_output.split(ctx.part_size, dim=ctx.dim)[ctx.rank].contiguous(),
453
+ None,
454
+ None,
455
+ )
456
+
457
+
458
+ def slice_tensor(tensor, dim, start, end):
459
+ indices = slice(start, end)
460
+ return tensor[(slice(None),) * dim + (indices,)]
461
+
462
+
463
+ def init_model_shard_cpu_group(sharding_strategy: str, device_mesh: Optional[Tuple] = None):
464
+ """
465
+ Initialize CPU process group of model sharding.
466
+ """
467
+ global _MODEL_SHARD_CPU_GROUP
468
+ assert dist.is_initialized()
469
+ world_size = dist.get_world_size()
470
+ rank = dist.get_rank()
471
+ if device_mesh is not None:
472
+ num_shards_per_group = device_mesh[1]
473
+ elif "HYBRID" in sharding_strategy:
474
+ num_shards_per_group = min(8, world_size)
475
+ else:
476
+ num_shards_per_group = world_size
477
+ num_groups = world_size // num_shards_per_group
478
+ for i in range(num_groups):
479
+ start_rank = i * num_shards_per_group
480
+ end_rank = (i + 1) * num_shards_per_group
481
+ ranks = range(start_rank, end_rank)
482
+ cpu_group = dist.new_group(ranks, backend="gloo")
483
+ if rank in ranks:
484
+ _MODEL_SHARD_CPU_GROUP = cpu_group
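
One detail of `advanced.py` that is easy to miss: when the sequence length is not divisible by the sequence-parallel world size, `gather_heads_scatter_seq` zero-pads the sequence with `pad_tensor` before the all-to-all, and the padding is trimmed again later with `unpad_tensor`. A minimal single-process sketch of that round trip, with the two helpers adapted from the diff (`sp_world=4` and the tensor shape are illustrative values, not settings from this repo):

```python
import torch

def pad_tensor(x, dim, padding_size):
    # zero-pad along `dim` so its size becomes divisible by the SP world size
    shape = list(x.shape)
    shape[dim] = padding_size
    pad = torch.zeros(shape, dtype=x.dtype, device=x.device)
    return torch.cat([x, pad], dim=dim)

def unpad_tensor(x, dim, padding_size):
    # drop the trailing rows added by pad_tensor
    slc = [slice(None)] * len(x.shape)
    slc[dim] = slice(0, -padding_size)
    return x[slc]

sp_world, seq_len = 4, 97                             # illustrative values
padding = sp_world - (seq_len % sp_world)             # 3
x = torch.randn(1, seq_len, 16)
padded = pad_tensor(x, dim=1, padding_size=padding)   # (1, 100, 16), now divisible by sp_world
restored = unpad_tensor(padded, dim=1, padding_size=padding)
assert torch.equal(restored, x)
```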
humo/common/distributed/basic.py ADDED
@@ -0,0 +1,143 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ # Codes adapted from [SeedVR]
13
+ # https://github.com/ByteDance-Seed/SeedVR/tree/main/common/distributed
14
+
15
+ """
16
+ Distributed basic functions.
17
+ """
18
+
19
+ import os
20
+ import torch
21
+ from torch import nn
22
+ import torch.distributed as dist
23
+ from torch.nn.parallel import DistributedDataParallel
24
+ from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
25
+
26
+
27
+ def get_global_rank() -> int:
28
+ """
29
+ Get the global rank, the global index of the GPU.
30
+ """
31
+ return int(os.environ.get("RANK", "0"))
32
+
33
+
34
+ def get_local_rank() -> int:
35
+ """
36
+ Get the local rank, the local index of the GPU.
37
+ """
38
+ return int(os.environ.get("LOCAL_RANK", "0"))
39
+
40
+
41
+ def get_world_size() -> int:
42
+ """
43
+ Get the world size, the total amount of GPUs.
44
+ """
45
+ return int(os.environ.get("WORLD_SIZE", "1"))
46
+
47
+
48
+ def get_device() -> torch.device:
49
+ """
50
+ Get current rank device.
51
+ """
52
+ return torch.device("cuda", get_local_rank())
53
+
54
+
55
+ def barrier_if_distributed(*args, **kwargs):
56
+ """
57
+ Synchronizes all processes if under distributed context.
58
+ """
59
+ if dist.is_initialized():
60
+ return dist.barrier(*args, **kwargs)
61
+
62
+
63
+ def init_torch(cudnn_benchmark=True):
64
+ """
65
+ Common PyTorch initialization configuration.
66
+ """
67
+ torch.backends.cuda.matmul.allow_tf32 = True
68
+ torch.backends.cudnn.allow_tf32 = True
69
+ torch.backends.cudnn.benchmark = cudnn_benchmark
70
+ torch.cuda.set_device(get_local_rank())
71
+ dist.init_process_group(
72
+ backend="nccl",
73
+ rank=get_global_rank(),
74
+ world_size=get_world_size(),
75
+ )
76
+
77
+
78
+ def convert_to_ddp(module: torch.nn.Module, **kwargs) -> DistributedDataParallel:
79
+ return DistributedDataParallel(
80
+ module=module,
81
+ device_ids=[get_local_rank()],
82
+ output_device=get_local_rank(),
83
+ **kwargs,
84
+ )
85
+
86
+
87
+ def meta_param_init_fn(module: nn.Module) -> None:
88
+ """
89
+ Used for models initialized on the meta device.
90
+ Initializes meta params/buffers with empty tensors;
91
+ numerical correctness does not matter here, because
92
+ FSDP will sync param/buffer state from rank0 to the other ranks.
93
+ """
94
+
95
+ with torch.no_grad():
96
+ for submodule in module.modules():
97
+ for param_name, param in submodule.named_parameters(recurse=False):
98
+ if not _is_fsdp_flattened(param) and param.is_meta:
99
+ materialized_param = nn.Parameter(torch.empty_like(param, device="cpu"))
100
+ setattr(submodule, param_name, materialized_param)
101
+ for buffer_name, buffer in submodule.named_buffers(recurse=False):
102
+ if not _is_fsdp_flattened(buffer) and buffer.is_meta:
103
+ materialized_param = torch.empty_like(buffer, device="cpu")
104
+ setattr(submodule, buffer_name, materialized_param)
105
+ torch.cuda.empty_cache()
106
+
107
+
108
+ def meta_non_persistent_buffer_init_fn(module: nn.Module) -> nn.Module:
109
+ """
110
+ Materialize meta device buffers that are not persistent in state_dict.
111
+ Handles special cases like RotaryEmbedding.freqs.
112
+ """
113
+ with torch.no_grad():
114
+ for submodule in module.modules():
115
+ if hasattr(submodule, "freqs"):
116
+ freqs = getattr(submodule, "freqs")
117
+ if isinstance(freqs, torch.Tensor) and freqs.is_meta:
118
+ dim = submodule.dim
119
+ def rope_params(max_seq_len, dim, theta=10000):
120
+ assert dim % 2 == 0
121
+ freqs = torch.outer(
122
+ torch.arange(max_seq_len),
123
+ 1.0 / torch.pow(theta,
124
+ torch.arange(0, dim, 2).to(torch.float32).div(dim)))
125
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
126
+ return freqs
127
+
128
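+ # NOTE: dim / num_heads below are hard-coded for the 14B model (5120 / 40);
+ # use the commented 1.3B values (1536 / 12) when materializing the smaller
+ # model. This overrides the `submodule.dim` read above.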
+ dim = 5120 # 1536
129
+ num_heads = 40 # 12
130
+ # dim = 1536
131
+ # num_heads = 12
132
+ d = dim // num_heads
133
+ freqs_tensor = torch.cat([
134
+ rope_params(1024, d - 4 * (d // 6)),
135
+ rope_params(1024, 2 * (d // 6)),
136
+ rope_params(1024, 2 * (d // 6))
137
+ ], dim=1).to(dtype=torch.cfloat, device="cpu")
138
+
139
+ setattr(submodule, "freqs", freqs_tensor)
140
+ print(f"Successfully materialized freqs for {submodule.__class__.__name__}")
141
+
142
+ assert not any(b.is_meta for n, b in module.named_buffers())
143
+ return module
humo/common/logger.py ADDED
@@ -0,0 +1,44 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ # Codes adapted from [SeedVR]
13
+ # https://github.com/ByteDance-Seed/SeedVR/blob/main/common/logger.py
14
+
15
+ """
16
+ Logging utility functions.
17
+ """
18
+
19
+ import logging
20
+ import sys
21
+ from typing import Optional
22
+
23
+ from common.distributed import get_global_rank, get_local_rank, get_world_size
24
+
25
+ _default_handler = logging.StreamHandler(sys.stdout)
26
+ _default_handler.setFormatter(
27
+ logging.Formatter(
28
+ "%(asctime)s "
29
+ + (f"[Rank:{get_global_rank()}]" if get_world_size() > 1 else "")
30
+ + (f"[LocalRank:{get_local_rank()}]" if get_world_size() > 1 else "")
31
+ + "[%(threadName).12s][%(name)s][%(levelname).5s] "
32
+ + "%(message)s"
33
+ )
34
+ )
35
+
36
+
37
+ def get_logger(name: Optional[str] = None) -> logging.Logger:
38
+ """
39
+ Get a logger.
40
+ """
41
+ logger = logging.getLogger(name)
42
+ logger.addHandler(_default_handler)
43
+ logger.setLevel(logging.INFO)
44
+ return logger
humo/configs/inference/generate.yaml ADDED
@@ -0,0 +1,78 @@
1
+ __object__:
2
+ path: humo.generate
3
+ name: Generator
4
+
5
+ dit:
6
+ model:
7
+ __inherit__: humo/configs/models/Wan_14B_I2V.yaml
8
+ __object__:
9
+ path: humo.models.wan_modules.model_humo
10
+ name: WanModel
11
+ insert_audio: True
12
+ zero_vae_path: ./weights/HuMo/zero_vae_129frame.pt
13
+ zero_vae_720p_path: ./weights/HuMo/zero_vae_720p_161frame.pt
14
+ checkpoint_dir: ./weights/HuMo/HuMo-17B
15
+ compile: False
16
+ init_with_meta_device: True
17
+ gradient_checkpoint: True
18
+ fsdp:
19
+ sharding_strategy: _HYBRID_SHARD_ZERO2
20
+ sp_size: 1
21
+ dtype: bfloat16
22
+
23
+ vae:
24
+ checkpoint: ./weights/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth
25
+ vae_stride: [ 4, 8, 8 ]
26
+ scaling_factor: 0.9152
27
+ compile: False
28
+ grouping: True
29
+ use_sample: False
30
+ dtype: bfloat16
31
+
32
+ text:
33
+ t5_checkpoint: ./weights/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth
34
+ t5_tokenizer: ./weights/Wan2.1-T2V-1.3B/google/umt5-xxl
35
+ dropout: 0.1
36
+ dtype: bfloat16
37
+ fsdp:
38
+ enabled: True
39
+ sharding_strategy: HYBRID_SHARD
40
+
41
+ diffusion:
42
+ schedule:
43
+ type: lerp
44
+ T: 1000.0
45
+ sampler:
46
+ type: euler
47
+ prediction_type: v_lerp
48
+ timesteps:
49
+ training:
50
+ type: logitnormal
51
+ loc: 0.0
52
+ scale: 1.0
53
+ sampling:
54
+ type: uniform_trailing
55
+ steps: 50
56
+ shift: 5.0
57
+
58
+ audio:
59
+ vocal_separator: ./weights/HuMo/audio_separator/Kim_Vocal_2.onnx
60
+ wav2vec_model: ./weights/whisper-large-v3
61
+
62
+ generation:
63
+ mode: "TIA" # TA, TIA
64
+ extract_audio_feat: True
65
+ seed: 666666
66
+ frames: 97
67
+ fps: 25
68
+ height: 480 # 720 480
69
+ width: 832 # 1280 832
70
+ batch_size: 1
71
+ sequence_parallel: 8
72
+ output:
73
+ dir: ./output
74
+ # positive_prompt: ./examples/test_case.json
75
+ sample_neg_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
76
+ scale_a: 5.5
77
+ scale_t: 5.0
78
+ step_change: 980
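+ # scale_a / scale_t above are the audio and text guidance strengths used by
+ # the multi-condition classifier-free guidance in humo/generate.py;
+ # step_change is the timestep threshold at which the negative branch switches
+ # from the image-conditioned pass to the fully unconditional one.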
humo/configs/inference/generate_1_7B.yaml ADDED
@@ -0,0 +1,76 @@
1
+ __object__:
2
+ path: humo.generate_1_7B
3
+ name: Generator
4
+
5
+ dit:
6
+ model:
7
+ __inherit__: humo/configs/models/Wan_1.3B.yaml
8
+ __object__:
9
+ path: humo.models.wan_modules.model_humo
10
+ name: WanModel
11
+ insert_audio: True
12
+ zero_vae_path: ./weights/HuMo/zero_vae_129frame.pt
13
+ zero_vae_720p_path: ./weights/HuMo/zero_vae_720p_161frame.pt
14
+ checkpoint_dir: ./weights/HuMo/HuMo-1.7B/ema.pth #./weights/HuMo/HuMo-17B
15
+ compile: False
16
+ init_with_meta_device: True
17
+ gradient_checkpoint: True
18
+ fsdp:
19
+ sharding_strategy: _HYBRID_SHARD_ZERO2
20
+ sp_size: 1
21
+
22
+ vae:
23
+ checkpoint: ./weights/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth
24
+ vae_stride: [ 4, 8, 8 ]
25
+ scaling_factor: 0.9152
26
+ compile: False
27
+ grouping: True
28
+ use_sample: False
29
+ dtype: bfloat16
30
+
31
+ text:
32
+ t5_checkpoint: ./weights/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth
33
+ t5_tokenizer: ./weights/Wan2.1-T2V-1.3B/google/umt5-xxl
34
+ dropout: 0.1
35
+ dtype: bfloat16
36
+ fsdp:
37
+ enabled: True
38
+ sharding_strategy: HYBRID_SHARD
39
+
40
+ diffusion:
41
+ schedule:
42
+ type: lerp
43
+ T: 1000.0
44
+ sampler:
45
+ type: euler
46
+ prediction_type: v_lerp
47
+ timesteps:
48
+ training:
49
+ type: logitnormal
50
+ loc: 0.0
51
+ scale: 1.0
52
+ sampling:
53
+ type: uniform_trailing
54
+ steps: 50
55
+ shift: 5.0
56
+
57
+ audio:
58
+ vocal_separator: ./weights/audio_separator/Kim_Vocal_2.onnx
59
+ wav2vec_model: ./weights/whisper-large-v3
60
+
61
+ generation:
62
+ mode: "TIA" # TA, TIA
63
+ extract_audio_feat: True
64
+ seed: 666666
65
+ frames: 97
66
+ fps: 25
67
+ height: 720 # 480
68
+ width: 1280 # 832
69
+ batch_size: 1
70
+ output:
71
+ dir: ./output
72
+ sample_neg_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
73
+ scale_t: 7.5
74
+ scale_i: 4.0
75
+ scale_a: 7.5
76
+ # step_change: 980
humo/configs/models/Wan_1.3B.yaml ADDED
@@ -0,0 +1,17 @@
1
+ __object__:
2
+ path: ???
3
+ name: ???
4
+ args: as_params
5
+
6
+ text_len: 512
7
+ patch_size: [ 1, 2, 2 ]
8
+ dim: 1536
9
+ ffn_dim: 8960
10
+ freq_dim: 256
11
+ model_type: "t2v"
12
+ num_heads: 12
13
+ num_layers: 30
14
+ window_size: [ -1, -1 ]
15
+ qk_norm: True
16
+ cross_attn_norm: True
17
+ eps: 1e-6
humo/configs/models/Wan_1.3B_I2V.yaml ADDED
@@ -0,0 +1,18 @@
1
+ __object__:
2
+ path: ???
3
+ name: ???
4
+ args: as_params
5
+
6
+ text_len: 512
7
+ patch_size: [ 1, 2, 2 ]
8
+ dim: 1536
9
+ ffn_dim: 8960
10
+ freq_dim: 256
11
+ in_dim: 36
12
+ model_type: "i2v"
13
+ num_heads: 12
14
+ num_layers: 30
15
+ window_size: [ -1, -1 ]
16
+ qk_norm: True
17
+ cross_attn_norm: True
18
+ eps: 1e-6
humo/configs/models/Wan_14B.yaml ADDED
@@ -0,0 +1,17 @@
1
+ __object__:
2
+ path: ???
3
+ name: ???
4
+ args: as_params
5
+
6
+ text_len: 512
7
+ patch_size: [ 1, 2, 2 ]
8
+ dim: 5120
9
+ ffn_dim: 13824
10
+ freq_dim: 256
11
+ model_type: "t2v"
12
+ num_heads: 40
13
+ num_layers: 40
14
+ window_size: [ -1, -1 ]
15
+ qk_norm: True
16
+ cross_attn_norm: True
17
+ eps: 1e-6
humo/configs/models/Wan_14B_I2V.yaml ADDED
@@ -0,0 +1,18 @@
1
+ __object__:
2
+ path: ???
3
+ name: ???
4
+ args: as_params
5
+
6
+ text_len: 512
7
+ patch_size: [ 1, 2, 2 ]
8
+ dim: 5120
9
+ ffn_dim: 13824
10
+ freq_dim: 256
11
+ in_dim: 36
12
+ model_type: "i2v"
13
+ num_heads: 40
14
+ num_layers: 40
15
+ window_size: [ -1, -1 ]
16
+ qk_norm: True
17
+ cross_attn_norm: True
18
+ eps: 1e-6
humo/generate.py ADDED
@@ -0,0 +1,984 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ # Inference codes adapted from [SeedVR]
13
+ # https://github.com/ByteDance-Seed/SeedVR/blob/main/projects/inference_seedvr2_7b.py
14
+
15
+ import math
16
+ import os
17
+ import gc
18
+ import random
19
+ import sys
20
+ import mediapy
21
+ import numpy as np
22
+ import torch
23
+ import torch.distributed as dist
24
+ from omegaconf import DictConfig, ListConfig, OmegaConf
25
+ from einops import rearrange
26
+ from omegaconf import OmegaConf
27
+ from PIL import Image, ImageOps
28
+ from torchvision.transforms import ToTensor
29
+ from tqdm import tqdm
30
+ from torch.distributed.device_mesh import init_device_mesh
31
+ from torch.distributed.fsdp import (
32
+ BackwardPrefetch,
33
+ FullyShardedDataParallel,
34
+ MixedPrecision,
35
+ ShardingStrategy,
36
+ )
37
+ from common.distributed import (
38
+ get_device,
39
+ get_global_rank,
40
+ get_local_rank,
41
+ meta_param_init_fn,
42
+ meta_non_persistent_buffer_init_fn,
43
+ init_torch,
44
+ )
45
+ from common.distributed.advanced import (
46
+ init_unified_parallel,
47
+ get_unified_parallel_world_size,
48
+ get_sequence_parallel_rank,
49
+ init_model_shard_cpu_group,
50
+ )
51
+ from common.logger import get_logger
52
+ from common.config import create_object
53
+ from common.distributed import get_device, get_global_rank
54
+ from torchvision.transforms import Compose, Normalize, ToTensor
55
+ from humo.models.wan_modules.t5 import T5EncoderModel
56
+ from humo.models.wan_modules.vae import WanVAE
57
+ from humo.models.utils.utils import tensor_to_video, prepare_json_dataset
58
+ from contextlib import contextmanager
59
+ import torch.cuda.amp as amp
60
+ from humo.models.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
61
+ from humo.utils.audio_processor_whisper import AudioProcessor
62
+ from humo.utils.wav2vec import linear_interpolation_fps
63
+ from torchao.quantization import quantize_
64
+
65
+ import torch._dynamo as dynamo
66
+ dynamo.config.capture_scalar_outputs = True
67
+ torch.set_float32_matmul_precision("high")
68
+
69
+ import torch
70
+ import torch.nn as nn
71
+ import transformer_engine.pytorch as te
72
+
73
+ image_transform = Compose([
74
+ ToTensor(),
75
+ Normalize(mean=0.5, std=0.5),
76
+ ])
77
+
78
+ SIZE_CONFIGS = {
79
+ '720*1280': (720, 1280),
80
+ '1280*720': (1280, 720),
81
+ '480*832': (480, 832),
82
+ '832*480': (832, 480),
83
+ '1024*1024': (1024, 1024),
84
+ }
85
+
86
+ def clever_format(nums, format="%.2f"):
87
+ from typing import Iterable
88
+ if not isinstance(nums, Iterable):
89
+ nums = [nums]
90
+ clever_nums = []
91
+ for num in nums:
92
+ if num > 1e12:
93
+ clever_nums.append(format % (num / 1e12) + "T")
94
+ elif num > 1e9:
95
+ clever_nums.append(format % (num / 1e9) + "G")
96
+ elif num > 1e6:
97
+ clever_nums.append(format % (num / 1e6) + "M")
98
+ elif num > 1e3:
99
+ clever_nums.append(format % (num / 1e3) + "K")
100
+ else:
101
+ clever_nums.append(format % num + "B")
102
+
103
+ clever_nums = clever_nums[0] if len(clever_nums) == 1 else (*clever_nums,)
104
+
105
+ return clever_nums
106
+
107
+
108
+
109
+ # --- put near your imports ---
110
+ import torch
111
+ import torch.nn as nn
112
+ import contextlib
113
+ import transformer_engine.pytorch as te
114
+
115
+ # FP8 autocast compatibility for different TE versions
116
+ try:
117
+ # Preferred modern API
118
+ from transformer_engine.pytorch import fp8_autocast
119
+ try:
120
+ # Newer TE: use recipe-based API
121
+ from transformer_engine.common.recipe import DelayedScaling, Format
122
+ def make_fp8_ctx(enabled: bool = True):
123
+ if not enabled:
124
+ return contextlib.nullcontext()
125
+ fp8_recipe = DelayedScaling(fp8_format=Format.E4M3) # E4M3 format
126
+ return fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)
127
+ except Exception:
128
+ # Very old variant that might still accept fp8_format directly
129
+ def make_fp8_ctx(enabled: bool = True):
130
+ # If TE doesn't have FP8Format, just no-op
131
+ if not hasattr(te, "FP8Format"):
132
+ return contextlib.nullcontext()
133
+ return te.fp8_autocast(enabled=enabled, fp8_format=te.FP8Format.E4M3)
134
+ except Exception:
135
+ # TE not present or totally incompatible — no-op
136
+ def make_fp8_ctx(enabled: bool = True):
137
+ return contextlib.nullcontext()
138
+
139
+
140
+ # TE sometimes exposes Linear at different paths; this normalizes it.
141
+ try:
142
+ TELinear = te.Linear
143
+ except AttributeError: # very old layouts
144
+ from transformer_engine.pytorch.modules.linear import Linear as TELinear # type: ignore
145
+
165
+ def _default_te_allow(fullname: str, lin: nn.Linear) -> bool:
166
+ """
167
+ Allow TE only where it's shape-safe & beneficial.
168
+ Skip small/special layers (time/timestep/pos embeds, heads).
169
+ Enforce multiples of 16 for in/out features (FP8 kernel friendly).
170
+ Also skip very small projections likely to see M=1.
171
+ """
172
+ blocked_keywords = (
173
+ "time_embedding", "timestep", "time_embed",
174
+ "time_projection", "pos_embedding", "pos_embed",
175
+ "to_logits", "logits", "final_proj", "proj_out", "output_projection",
176
+ )
177
+ if any(k in fullname for k in blocked_keywords):
178
+ return False
179
+
180
+ # TE FP8 kernels like K, N divisible by 16
181
+ if lin.in_features % 16 != 0 or lin.out_features % 16 != 0:
182
+ return False
183
+
184
+ # Heuristic: avoid tiny layers; keeps attention/MLP, skips small MLPs
185
+ if lin.in_features < 512 or lin.out_features < 512:
186
+ return False
187
+
188
+ # Whitelist: only convert inside transformer blocks if you know their prefix
189
+ # This further reduces risk of catching special heads elsewhere.
190
+ allowed_context = ("blocks", "layers", "transformer", "attn", "mlp", "ffn")
191
+ if not any(tok in fullname for tok in allowed_context):
192
+ return False
193
+
194
+ return True
195
+
196
+ @torch.no_grad()
197
+ def convert_linears_to_te_fp8(module: nn.Module, allow_pred=_default_te_allow, _prefix=""):
198
+ for name, child in list(module.named_children()):
199
+ full = f"{_prefix}.{name}" if _prefix else name
200
+ convert_linears_to_te_fp8(child, allow_pred, full)
201
+
202
+ if isinstance(child, nn.Linear):
203
+ if allow_pred is not None and not allow_pred(full, child):
204
+ continue
205
+
206
+ te_lin = TELinear(
207
+ in_features=child.in_features,
208
+ out_features=child.out_features,
209
+ bias=(child.bias is not None),
210
+ params_dtype=torch.bfloat16,
211
+ ).to(child.weight.device)
212
+
213
+ te_lin.weight.copy_(child.weight.to(te_lin.weight.dtype))
214
+ if child.bias is not None:
215
+ te_lin.bias.copy_(child.bias.to(te_lin.bias.dtype))
216
+
217
+ setattr(module, name, te_lin)
218
+ return module
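+
+ # Note: the swapped-in te.Linear layers only run FP8 GEMMs when the forward
+ # pass executes inside make_fp8_ctx(True) (as done in the sampling loop in
+ # `inference` below); outside that context they fall back to bf16 matmuls.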
219
+
220
+ class Generator():
221
+ def __init__(self, config: DictConfig):
222
+ self.config = config.copy()
223
+ OmegaConf.set_readonly(self.config, True)
224
+ self.logger = get_logger(self.__class__.__name__)
225
+
226
+ # init_torch(cudnn_benchmark=False)
227
+ self.configure_models()
228
+
229
+ def entrypoint(self):
230
+
231
+ self.inference_loop()
232
+
233
+ def get_fsdp_sharding_config(self, sharding_strategy, device_mesh_config):
234
+ device_mesh = None
235
+ fsdp_strategy = ShardingStrategy[sharding_strategy]
236
+ if (
237
+ fsdp_strategy in [ShardingStrategy._HYBRID_SHARD_ZERO2, ShardingStrategy.HYBRID_SHARD]
238
+ and device_mesh_config is not None
239
+ ):
240
+ device_mesh = init_device_mesh("cuda", tuple(device_mesh_config))
241
+ return device_mesh, fsdp_strategy
242
+
243
+
244
+ def configure_models(self):
245
+ self.configure_dit_model(device="cuda")
246
+
247
+ self.dit.eval().to("cuda")
248
+ convert_linears_to_te_fp8(self.dit)
249
+
250
+ self.dit = torch.compile(self.dit)
251
+
252
+
253
+ self.configure_vae_model(device="cuda")
254
+ if self.config.generation.get('extract_audio_feat', False):
255
+ self.configure_wav2vec(device="cpu")
256
+ self.configure_text_model(device="cuda")
257
+
258
+ # # Initialize fsdp.
259
+ # self.configure_dit_fsdp_model()
260
+ # self.configure_text_fsdp_model()
261
+
262
+ # quantize_(self.text_encoder, Int8WeightOnlyConfig())
263
+ # quantize_(self.dit, Float8DynamicActivationFloat8WeightConfig())
264
+
265
+
266
+ def configure_dit_model(self, device=get_device()):
267
+
268
+ init_unified_parallel(self.config.dit.sp_size)
269
+ self.sp_size = get_unified_parallel_world_size()
270
+
271
+ # Create DiT model on meta, then mark dtype as bfloat16 (no real allocation yet).
272
+ init_device = "meta"
273
+ with torch.device(init_device):
274
+ self.dit = create_object(self.config.dit.model)
275
+ self.dit = self.dit.to(dtype=torch.bfloat16) # or: self.dit.bfloat16()
276
+ self.logger.info(f"Load DiT model on {init_device}.")
277
+ self.dit.eval().requires_grad_(False)
278
+
279
+ # Load dit checkpoint.
280
+ path = self.config.dit.checkpoint_dir
281
+
282
+ def _cast_state_dict_to_bf16(state):
283
+ for k, v in state.items():
284
+ if isinstance(v, torch.Tensor) and v.is_floating_point():
285
+ state[k] = v.to(dtype=torch.bfloat16, copy=False)
286
+ return state
287
+
288
+ if path.endswith(".pth"):
289
+ # Load to CPU first; we’ll move the model later.
290
+ state = torch.load(path, map_location="cpu", mmap=True)
291
+ state = _cast_state_dict_to_bf16(state)
292
+ missing_keys, unexpected_keys = self.dit.load_state_dict(state, strict=False, assign=True)
293
+ self.logger.info(
294
+ f"dit loaded from {path}. Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}"
295
+ )
296
+ else:
297
+ from safetensors.torch import load_file
298
+ import json
299
+ def load_custom_sharded_weights(model_dir, base_name):
300
+ index_path = f"{model_dir}/{base_name}.safetensors.index.json"
301
+ with open(index_path, "r") as f:
302
+ index = json.load(f)
303
+ weight_map = index["weight_map"]
304
+ shard_files = set(weight_map.values())
305
+ state_dict = {}
306
+ for shard_file in shard_files:
307
+ shard_path = f"{model_dir}/{shard_file}"
308
+ # Load on CPU, then cast to bf16; we’ll move the whole module later.
309
+ shard_state = load_file(shard_path, device="cpu")
310
+ shard_state = {k: (v.to(dtype=torch.bfloat16, copy=False) if v.is_floating_point() else v)
311
+ for k, v in shard_state.items()}
312
+ state_dict.update(shard_state)
313
+ return state_dict
314
+
315
+ state = load_custom_sharded_weights(path, 'humo')
316
+ self.dit.load_state_dict(state, strict=False, assign=True)
317
+
318
+ self.dit = meta_non_persistent_buffer_init_fn(self.dit)
319
+
320
+ target_device = get_device() if device in [get_device(), "cuda"] else device
321
+ self.dit.to(target_device) # dtype already bf16
322
+
323
+ # Print model size.
324
+ params = sum(p.numel() for p in self.dit.parameters())
325
+ self.logger.info(
326
+ f"[RANK:{get_global_rank()}] DiT Parameters: {clever_format(params, '%.3f')}"
327
+ )
328
+
329
+
330
+ def configure_vae_model(self, device=get_device()):
331
+ self.vae_stride = self.config.vae.vae_stride
332
+ self.vae = WanVAE(
333
+ vae_pth=self.config.vae.checkpoint,
334
+ device=device)
335
+
336
+ if self.config.generation.height == 480:
337
+ self.zero_vae = torch.load(self.config.dit.zero_vae_path)
338
+ elif self.config.generation.height == 720:
339
+ self.zero_vae = torch.load(self.config.dit.zero_vae_720p_path)
340
+ else:
341
+ raise ValueError(f"Unsupported height {self.config.generation.height} for zero-vae.")
342
+
343
+ def configure_wav2vec(self, device=get_device()):
344
+ audio_separator_model_file = self.config.audio.vocal_separator
345
+ wav2vec_model_path = self.config.audio.wav2vec_model
346
+
347
+ self.audio_processor = AudioProcessor(
348
+ 16000,
349
+ 25,
350
+ wav2vec_model_path,
351
+ "all",
352
+ audio_separator_model_file,
353
+ None, # do not separate vocals
354
+ os.path.join(self.config.generation.output.dir, "vocals"),
355
+ device=device,
356
+ )
357
+
358
+ def configure_text_model(self, device=get_device()):
359
+ self.text_encoder = T5EncoderModel(
360
+ text_len=self.config.dit.model.text_len,
361
+ dtype=torch.bfloat16,
362
+ device=device,
363
+ checkpoint_path=self.config.text.t5_checkpoint,
364
+ tokenizer_path=self.config.text.t5_tokenizer,
365
+ )
366
+
367
+
368
+ def configure_dit_fsdp_model(self):
369
+ from humo.models.wan_modules.model_humo import WanAttentionBlock
370
+
371
+ dit_blocks = (WanAttentionBlock,)
372
+
373
+ # Init model_shard_cpu_group for saving checkpoint with sharded state_dict.
374
+ init_model_shard_cpu_group(
375
+ self.config.dit.fsdp.sharding_strategy,
376
+ self.config.dit.fsdp.get("device_mesh", None),
377
+ )
378
+
379
+ # Assert that dit has wrappable blocks.
380
+ assert any(isinstance(m, dit_blocks) for m in self.dit.modules())
381
+
382
+ # Define wrap policy on all dit blocks.
383
+ def custom_auto_wrap_policy(module, recurse, *args, **kwargs):
384
+ return recurse or isinstance(module, dit_blocks)
385
+
386
+ # Configure FSDP settings.
387
+ device_mesh, fsdp_strategy = self.get_fsdp_sharding_config(
388
+ self.config.dit.fsdp.sharding_strategy,
389
+ self.config.dit.fsdp.get("device_mesh", None),
390
+ )
391
+ settings = dict(
392
+ auto_wrap_policy=custom_auto_wrap_policy,
393
+ sharding_strategy=fsdp_strategy,
394
+ backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
395
+ device_id=get_local_rank(),
396
+ use_orig_params=False,
397
+ sync_module_states=True,
398
+ forward_prefetch=True,
399
+ limit_all_gathers=False, # False for ZERO2.
400
+ mixed_precision=MixedPrecision(
401
+ param_dtype=torch.bfloat16,
402
+ reduce_dtype=torch.float32,
403
+ buffer_dtype=torch.float32,
404
+ ),
405
+ device_mesh=device_mesh,
406
+ param_init_fn=meta_param_init_fn,
407
+ )
408
+
409
+ # Apply FSDP.
410
+ self.dit = FullyShardedDataParallel(self.dit, **settings)
411
+ # self.dit.to(get_device())
412
+
413
+
414
+ def configure_text_fsdp_model(self):
415
+ # If FSDP is not enabled, put text_encoder to GPU and return.
416
+ if not self.config.text.fsdp.enabled:
417
+ self.text_encoder.to(get_device())
418
+ return
419
+
420
+ # from transformers.models.t5.modeling_t5 import T5Block
421
+ from humo.models.wan_modules.t5 import T5SelfAttention
422
+
423
+ text_blocks = (torch.nn.Embedding, T5SelfAttention)
424
+ # text_blocks_names = ("QWenBlock", "QWenModel") # QWen cannot be imported. Use str.
425
+
426
+ def custom_auto_wrap_policy(module, recurse, *args, **kwargs):
427
+ return (
428
+ recurse
429
+ or isinstance(module, text_blocks)
430
+ )
431
+
432
+ # Apply FSDP.
433
+ text_encoder_dtype = getattr(torch, self.config.text.dtype)
434
+ device_mesh, fsdp_strategy = self.get_fsdp_sharding_config(
435
+ self.config.text.fsdp.sharding_strategy,
436
+ self.config.text.fsdp.get("device_mesh", None),
437
+ )
438
+ self.text_encoder = FullyShardedDataParallel(
439
+ module=self.text_encoder,
440
+ auto_wrap_policy=custom_auto_wrap_policy,
441
+ sharding_strategy=fsdp_strategy,
442
+ backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
443
+ device_id=get_local_rank(),
444
+ use_orig_params=False,
445
+ sync_module_states=False,
446
+ forward_prefetch=True,
447
+ limit_all_gathers=True,
448
+ mixed_precision=MixedPrecision(
449
+ param_dtype=text_encoder_dtype,
450
+ reduce_dtype=text_encoder_dtype,
451
+ buffer_dtype=text_encoder_dtype,
452
+ ),
453
+ device_mesh=device_mesh,
454
+ )
455
+ self.text_encoder.to(get_device()).requires_grad_(False)
456
+
457
+
458
+ def load_image_latent_ref_id(self, path: str, size, device):
459
+ # Load size.
460
+ h, w = size[1], size[0]
461
+
462
+ # Load image.
463
+ if len(path) > 1 and not isinstance(path, str):
464
+ ref_vae_latents = []
465
+ for image_path in path:
466
+ with Image.open(image_path) as img:
467
+ img = img.convert("RGB")
468
+
469
+ # Calculate the required size to keep aspect ratio and fill the rest with padding.
470
+ img_ratio = img.width / img.height
471
+ target_ratio = w / h
472
+
473
+ if img_ratio > target_ratio: # Image is wider than target
474
+ new_width = w
475
+ new_height = int(new_width / img_ratio)
476
+ else: # Image is taller than target
477
+ new_height = h
478
+ new_width = int(new_height * img_ratio)
479
+
480
+ # img = img.resize((new_width, new_height), Image.ANTIALIAS)
481
+ img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
482
+
483
+ # Create a new image with the target size and place the resized image in the center
484
+ delta_w = w - img.size[0]
485
+ delta_h = h - img.size[1]
486
+ padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
487
+ new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
488
+
489
+ # Transform to tensor and normalize.
490
+ transform = Compose(
491
+ [
492
+ ToTensor(),
493
+ Normalize(0.5, 0.5),
494
+ ]
495
+ )
496
+ new_img = transform(new_img)
497
+ # img_vae_latent = self.vae_encode([new_img.unsqueeze(1)])[0]
498
+ img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device)
499
+ ref_vae_latents.append(img_vae_latent[0])
500
+
501
+ return [torch.cat(ref_vae_latents, dim=1)]
502
+ else:
503
+ if not isinstance(path, str):
504
+ path = path[0]
505
+ with Image.open(path) as img:
506
+ img = img.convert("RGB")
507
+
508
+ # Calculate the required size to keep aspect ratio and fill the rest with padding.
509
+ img_ratio = img.width / img.height
510
+ target_ratio = w / h
511
+
512
+ if img_ratio > target_ratio: # Image is wider than target
513
+ new_width = w
514
+ new_height = int(new_width / img_ratio)
515
+ else: # Image is taller than target
516
+ new_height = h
517
+ new_width = int(new_height * img_ratio)
518
+
519
+ # img = img.resize((new_width, new_height), Image.ANTIALIAS)
520
+ img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
521
+
522
+ # Create a new image with the target size and place the resized image in the center
523
+ delta_w = w - img.size[0]
524
+ delta_h = h - img.size[1]
525
+ padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
526
+ new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
527
+
528
+ # Transform to tensor and normalize.
529
+ transform = Compose(
530
+ [
531
+ ToTensor(),
532
+ Normalize(0.5, 0.5),
533
+ ]
534
+ )
535
+ new_img = transform(new_img)
536
+ img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device)
537
+
538
+ # Vae encode.
539
+ return img_vae_latent
540
+
541
+ def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2):
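+ # Builds one audio-feature window per latent frame: the first latent frame
+ # (1 video frame) gets a 5-frame window prepended with 3 zero frames, and
+ # every later latent frame (4 video frames) gets a (4 + 2*audio_shift)-frame
+ # window; indices outside the audio sequence are filled with zeros.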
542
+ zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
543
+ zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) # device=audio_emb.device
544
+ iter_ = 1 + (frame_num - 1) // 4
545
+ audio_emb_wind = []
546
+ for lt_i in range(iter_):
547
+ if lt_i == 0:
548
+ st = frame0_idx + lt_i - 2
549
+ ed = frame0_idx + lt_i + 3
550
+ wind_feat = torch.stack([
551
+ audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
552
+ for i in range(st, ed)
553
+ ], dim=0)
554
+ wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0)
555
+ else:
556
+ st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
557
+ ed = frame0_idx + 1 + 4 * lt_i + audio_shift
558
+ wind_feat = torch.stack([
559
+ audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
560
+ for i in range(st, ed)
561
+ ], dim=0)
562
+ audio_emb_wind.append(wind_feat)
563
+ audio_emb_wind = torch.stack(audio_emb_wind, dim=0)
564
+
565
+ return audio_emb_wind, ed - audio_shift
566
+
567
+ def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"):
568
+ if wav_enc_type == "wav2vec":
569
+ feat_merge = audio_emb
570
+ elif wav_enc_type == "whisper":
571
+ feat0 = linear_interpolation_fps(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25)
572
+ feat1 = linear_interpolation_fps(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25)
573
+ feat2 = linear_interpolation_fps(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25)
574
+ feat3 = linear_interpolation_fps(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25)
575
+ feat4 = linear_interpolation_fps(audio_emb[:, :, 32], 50, 25)
576
+ feat_merge = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0]
577
+ else:
578
+ raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}")
579
+
580
+ return feat_merge
581
+
582
+ def parse_output(self, output):
583
+ latent = output[0]
584
+ mask = None
585
+ return latent, mask
586
+
587
+ def forward_tia(self, latents, timestep, t, step_change, arg_tia, arg_ti, arg_i, arg_null):
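+ # Two-axis classifier-free guidance:
+ #   pred = neg + scale_t * (pos_ti - neg) + scale_a * (pos_tia - pos_ti)
+ # While t > step_change (early steps) the negative branch keeps the reference
+ # image (arg_i); afterwards it falls back to arg_null with a reduced text scale.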
588
+ pos_tia, _ = self.parse_output(self.dit(
589
+ latents, t=timestep, **arg_tia
590
+ ))
591
+ torch.cuda.empty_cache()
592
+
593
+ pos_ti, _ = self.parse_output(self.dit(
594
+ latents, t=timestep, **arg_ti
595
+ ))
596
+ torch.cuda.empty_cache()
597
+
598
+ if t > step_change:
599
+ neg, _ = self.parse_output(self.dit(
600
+ latents, t=timestep, **arg_i
601
+ )) # img included in null, same as official Wan-2.1
602
+ torch.cuda.empty_cache()
603
+
604
+ noise_pred = self.config.generation.scale_a * (pos_tia - pos_ti) + \
605
+ self.config.generation.scale_t * (pos_ti - neg) + \
606
+ neg
607
+ else:
608
+ neg, _ = self.parse_output(self.dit(
609
+ latents, t=timestep, **arg_null
610
+ )) # img not included in null
611
+ torch.cuda.empty_cache()
612
+
613
+ noise_pred = self.config.generation.scale_a * (pos_tia - pos_ti) + \
614
+ (self.config.generation.scale_t - 2.0) * (pos_ti - neg) + \
615
+ neg
616
+ return noise_pred
617
+
618
+ def forward_ti(self, latents, timestep, t, step_change, arg_ti, arg_t, arg_i, arg_null):
619
+ # Positive with text+image (no audio)
620
+ pos_ti, _ = self.parse_output(self.dit(
621
+ latents, t=timestep, **arg_ti
622
+ ))
623
+ torch.cuda.empty_cache()
624
+
625
+ # Positive with text only (no image, no audio)
626
+ pos_t, _ = self.parse_output(self.dit(
627
+ latents, t=timestep, **arg_t
628
+ ))
629
+ torch.cuda.empty_cache()
630
+
631
+ # Negative branch: while t > step_change (early steps) include the reference image in the null pass (like Wan-2.1); once t drops below, use the fully unconditional pass
632
+ if t > step_change:
633
+ neg, _ = self.parse_output(self.dit(
634
+ latents, t=timestep, **arg_i
635
+ )) # img included in null
636
+ else:
637
+ neg, _ = self.parse_output(self.dit(
638
+ latents, t=timestep, **arg_null
639
+ )) # img NOT included in null
640
+ torch.cuda.empty_cache()
641
+
642
+ # Guidance blend: replace "scale_a" below with "scale_i" if you add a separate image scale in config
643
+ noise_pred = self.config.generation.scale_a * (pos_ti - pos_t) + \
644
+ self.config.generation.scale_t * (pos_t - neg) + \
645
+ neg
646
+ return noise_pred
647
+
648
+ def forward_ta(self, latents, timestep, arg_ta, arg_t, arg_null):
649
+ pos_ta, _ = self.parse_output(self.dit(
650
+ latents, t=timestep, **arg_ta
651
+ ))
652
+ torch.cuda.empty_cache()
653
+
654
+ pos_t, _ = self.parse_output(self.dit(
655
+ latents, t=timestep, **arg_t
656
+ ))
657
+ torch.cuda.empty_cache()
658
+
659
+ neg, _ = self.parse_output(self.dit(
660
+ latents, t=timestep, **arg_null
661
+ ))
662
+ torch.cuda.empty_cache()
663
+
664
+ noise_pred = self.config.generation.scale_a * (pos_ta - pos_t) + \
665
+ self.config.generation.scale_t * (pos_t - neg) + \
666
+ neg
667
+ return noise_pred
668
+
669
+ @torch.no_grad()
670
+ def inference(self,
671
+ input_prompt,
672
+ img_path,
673
+ audio_path,
674
+ size=(1280, 720),
675
+ frame_num=81,
676
+ shift=5.0,
677
+ sample_solver='unipc',
678
+ inference_mode='TIA',
679
+ sampling_steps=50,
680
+ n_prompt="",
681
+ seed=-1,
682
+ tea_cache_l1_thresh = 0.0,
683
+ device = get_device(),
684
+ ):
685
+
686
+ print("inference started")
687
+
688
+ # self.vae.model.to(device=device)
689
+ if img_path is not None:
690
+ latents_ref = self.load_image_latent_ref_id(img_path, size, device)
691
+ else:
692
+ latents_ref = [torch.zeros(16, 1, size[1]//8, size[0]//8).to(device)]
693
+
694
+ # self.vae.model.to(device="cpu")
695
+
696
+ print("vae finished")
697
+
698
+ latents_ref_neg = [torch.zeros_like(latent_ref) for latent_ref in latents_ref]
699
+
700
+ # audio
701
+ if audio_path is not None:
702
+ if self.config.generation.extract_audio_feat:
703
+ self.audio_processor.whisper.to(device=device)
704
+ audio_emb, audio_length = self.audio_processor.preprocess(audio_path)
705
+ self.audio_processor.whisper.to(device='cpu')
706
+ else:
707
+ audio_emb_path = audio_path.replace(".wav", ".pt")
708
+ audio_emb = torch.load(audio_emb_path).to(device=device)
709
+ audio_emb = self.audio_emb_enc(audio_emb, wav_enc_type="whisper")
710
+ self.logger.info("使用预先提取好的音频特征: %s", audio_emb_path)
711
+ else:
712
+ audio_emb = torch.zeros(frame_num, 5, 1280).to(device)
713
+
714
+ frame_num = frame_num if frame_num != -1 else audio_length
715
+ frame_num = 4 * ((frame_num - 1) // 4) + 1
716
+ audio_emb, _ = self.get_audio_emb_window(audio_emb, frame_num, frame0_idx=0)
717
+ zero_audio_pad = torch.zeros(latents_ref[0].shape[1], *audio_emb.shape[1:]).to(audio_emb.device)
718
+ audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0)
719
+ audio_emb = [audio_emb.to(device)]
720
+ audio_emb_neg = [torch.zeros_like(audio_emb[0])]
721
+
722
+ # preprocess
723
+ self.patch_size = self.config.dit.model.patch_size
724
+ F = frame_num
725
+ target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + latents_ref[0].shape[1],
726
+ size[1] // self.vae_stride[1],
727
+ size[0] // self.vae_stride[2])
728
+
729
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) /
730
+ (self.patch_size[1] * self.patch_size[2]) *
731
+ target_shape[1] / self.sp_size) * self.sp_size
732
+
733
+ if n_prompt == "":
734
+ n_prompt = self.config.generation.sample_neg_prompt
735
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
736
+ seed_g = torch.Generator(device=device)
737
+ seed_g.manual_seed(seed)
738
+
739
+ # self.text_encoder.model.to(device)
740
+ context = self.text_encoder([input_prompt], device)
741
+ context_null = self.text_encoder([n_prompt], device)
742
+ # self.text_encoder.model.cpu()
743
+
744
+ print("text encoder finished")
745
+
746
+ noise = [
747
+ torch.randn(
748
+ target_shape[0],
749
+ target_shape[1], # - latents_ref[0].shape[1],
750
+ target_shape[2],
751
+ target_shape[3],
752
+ dtype=torch.float32,
753
+ device=device,
754
+ generator=seed_g)
755
+ ]
756
+
757
+ @contextmanager
758
+ def noop_no_sync():
759
+ yield
760
+
761
+ no_sync = getattr(self.dit, 'no_sync', noop_no_sync)
762
+ step_change = self.config.generation.step_change # 980
763
+
764
+ # evaluation mode
765
+ with make_fp8_ctx(True), torch.autocast('cuda', dtype=torch.bfloat16), torch.no_grad(), no_sync():
766
+
767
+ if sample_solver == 'unipc':
768
+ sample_scheduler = FlowUniPCMultistepScheduler(
769
+ num_train_timesteps=1000,
770
+ shift=1,
771
+ use_dynamic_shifting=False)
772
+ sample_scheduler.set_timesteps(
773
+ sampling_steps, device=device, shift=shift)
774
+ timesteps = sample_scheduler.timesteps
775
+
776
+ # sample videos
777
+ latents = noise
778
+
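+ # Conditioning input y: a 4-channel temporal mask (1 only on the trailing
+ # reference-latent frames) concatenated with zero-VAE latents for the video
+ # frames plus the reference image latents along the time axis; y_null uses
+ # zero-VAE latents for every frame instead of the reference latents.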
779
+ msk = torch.ones(4, target_shape[1], target_shape[2], target_shape[3], device=get_device())
780
+ msk[:,:-latents_ref[0].shape[1]] = 0
781
+
782
+ zero_vae = self.zero_vae[:, :(target_shape[1]-latents_ref[0].shape[1])].to(
783
+ device=get_device(), dtype=latents_ref[0].dtype)
784
+ y_c = torch.cat([
785
+ zero_vae,
786
+ latents_ref[0]
787
+ ], dim=1)
788
+ y_c = [torch.concat([msk, y_c])]
789
+
790
+ y_null = self.zero_vae[:, :target_shape[1]].to(
791
+ device=get_device(), dtype=latents_ref[0].dtype)
792
+ y_null = [torch.concat([msk, y_null])]
793
+
794
+ tea_cache_l1_thresh = tea_cache_l1_thresh
795
+ tea_cache_model_id = "Wan2.1-T2V-14B"
796
+
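+ # Conditioning dicts for the guidance branches: the suffix marks which
+ # conditions are active (t = text, i = reference-image latents, a = audio);
+ # "null" drops all of them. Each branch gets its own TeaCache instance when
+ # tea_cache_l1_thresh > 0.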
797
+ arg_null = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_null, 'context': context_null, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None}
798
+ arg_t = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_null, 'context': context, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None}
799
+ arg_i = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_c, 'context': context_null, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None}
800
+ arg_ti = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_c, 'context': context, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None}
801
+ arg_ta = {'seq_len': seq_len, 'audio': audio_emb, 'y': y_null, 'context': context, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None}
802
+ arg_tia = {'seq_len': seq_len, 'audio': audio_emb, 'y': y_c, 'context': context, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None}
803
+
804
+ torch.cuda.empty_cache()
805
+ # self.dit.to(device=get_device())
806
+ for _, t in enumerate(tqdm(timesteps)):
807
+ timestep = [t]
808
+ timestep = torch.stack(timestep)
809
+
810
+ if inference_mode == "TIA":
811
+ noise_pred = self.forward_tia(latents, timestep, t, step_change,
812
+ arg_tia, arg_ti, arg_i, arg_null)
813
+ elif inference_mode == "TA":
814
+ noise_pred = self.forward_ta(latents, timestep, arg_ta, arg_t, arg_null)
815
+ elif inference_mode == "TI":
816
+ noise_pred = self.forward_ti(latents, timestep, t, step_change,
817
+ arg_ti, arg_t, arg_i, arg_null)
818
+ else:
819
+ raise ValueError(f"Unsupported generation mode: {self.config.generation.mode}")
820
+
821
+ temp_x0 = sample_scheduler.step(
822
+ noise_pred.unsqueeze(0),
823
+ t,
824
+ latents[0].unsqueeze(0),
825
+ return_dict=False,
826
+ generator=seed_g)[0]
827
+ latents = [temp_x0.squeeze(0)]
828
+
829
+ del timestep
830
+ torch.cuda.empty_cache()
831
+
832
+ x0 = latents
833
+ x0 = [x0_[:,:-latents_ref[0].shape[1]] for x0_ in x0]
834
+
835
+ # if offload_model:
836
+ # self.dit.cpu()
837
+
838
+ print("dit finished")
839
+
840
+ torch.cuda.empty_cache()
841
+ # if get_local_rank() == 0:
842
+ # self.vae.model.to(device=device)
843
+ videos = self.vae.decode(x0)
844
+ # self.vae.model.to(device="cpu")
845
+
846
+ print("vae 2 finished")
847
+
848
+ del noise, latents, noise_pred
849
+ del audio_emb, audio_emb_neg, latents_ref, latents_ref_neg, context, context_null
850
+ del x0, temp_x0
851
+ del sample_scheduler
852
+ torch.cuda.empty_cache()
853
+ gc.collect()
854
+ torch.cuda.synchronize()
855
+ if dist.is_initialized():
856
+ dist.barrier()
857
+
858
+ return videos[0] # if get_local_rank() == 0 else None
859
+
860
+
861
+ def inference_loop(self, prompt, ref_img_path, audio_path, output_dir, filename, inference_mode = "TIA", width = 832, height = 480, steps=50, frames = 97, tea_cache_l1_thresh = 0.0, seed = 0):
862
+
863
+ video = self.inference(
864
+ prompt,
865
+ ref_img_path,
866
+ audio_path,
867
+ size=SIZE_CONFIGS[f"{width}*{height}"],
868
+ frame_num=frames,
869
+ shift=self.config.diffusion.timesteps.sampling.shift,
870
+ sample_solver='unipc',
871
+ sampling_steps=steps,
872
+ inference_mode = inference_mode,
873
+ tea_cache_l1_thresh = tea_cache_l1_thresh,
874
+ seed=seed
875
+ )
876
+
877
+ torch.cuda.empty_cache()
878
+ gc.collect()
879
+
880
+ # Save samples.
881
+ if get_sequence_parallel_rank() == 0:
882
+ pathname = self.save_sample(
883
+ sample=video,
884
+ audio_path=audio_path,
885
+ output_dir = output_dir,
886
+ filename=filename,
887
+ )
888
+ self.logger.info(f"Finished {filename}, saved to {pathname}.")
889
+
890
+ del video, prompt
891
+ torch.cuda.empty_cache()
892
+ gc.collect()
893
+
894
+
895
+ def save_sample(self, *, sample: torch.Tensor, audio_path: str, output_dir: str, filename: str):
896
+ gen_config = self.config.generation
897
+ # Prepare file path.
898
+ extension = ".mp4" if sample.ndim == 4 else ".png"
899
+ filename += extension
900
+ pathname = os.path.join(output_dir, filename)
901
+ # Convert sample.
902
+ sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).to("cpu", torch.uint8)
903
+ sample = rearrange(sample, "c t h w -> t h w c")
904
+ # Save file.
905
+ if sample.ndim == 4:
906
+ if audio_path is not None:
907
+ tensor_to_video(
908
+ sample.numpy(),
909
+ pathname,
910
+ audio_path,
911
+ fps=gen_config.fps)
912
+ else:
913
+ mediapy.write_video(
914
+ path=pathname,
915
+ images=sample.numpy(),
916
+ fps=gen_config.fps,
917
+ )
918
+ else:
919
+ raise ValueError
920
+ return pathname
921
+
922
+
923
+ def prepare_positive_prompts(self):
924
+ pos_prompts = self.config.generation.positive_prompt
925
+ if pos_prompts.endswith(".json"):
926
+ pos_prompts = prepare_json_dataset(pos_prompts)
927
+ else:
928
+ raise NotImplementedError
929
+ assert isinstance(pos_prompts, ListConfig)
930
+
931
+ return pos_prompts
932
+
933
+ class TeaCache:
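+ # Skips redundant DiT forwards across sampling steps: `check` accumulates a
+ # polynomial-rescaled relative L1 distance of the modulated input and returns
+ # True when the cached residual can be reused, `store` records the residual
+ # after a full forward, and `update` re-applies it to the hidden states.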
934
+ def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
935
+ self.num_inference_steps = num_inference_steps
936
+ self.step = 0
937
+ self.accumulated_rel_l1_distance = 0
938
+ self.previous_modulated_input = None
939
+ self.rel_l1_thresh = rel_l1_thresh
940
+ self.previous_residual = None
941
+ self.previous_hidden_states = None
942
+
943
+ self.coefficients_dict = {
944
+ "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
945
+ "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
946
+ "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
947
+ "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
948
+ }
949
+ if model_id not in self.coefficients_dict:
950
+ supported_model_ids = ", ".join([i for i in self.coefficients_dict])
951
+ raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
952
+ self.coefficients = self.coefficients_dict[model_id]
953
+
954
+ def check(self, dit, x, t_mod):
955
+ modulated_inp = t_mod.clone()
956
+ if self.step == 0 or self.step == self.num_inference_steps - 1:
957
+ should_calc = True
958
+ self.accumulated_rel_l1_distance = 0
959
+ else:
960
+ coefficients = self.coefficients
961
+ rescale_func = np.poly1d(coefficients)
962
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
963
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
964
+ should_calc = False
965
+ else:
966
+ should_calc = True
967
+ self.accumulated_rel_l1_distance = 0
968
+ self.previous_modulated_input = modulated_inp
969
+ self.step += 1
970
+ if self.step == self.num_inference_steps:
971
+ self.step = 0
972
+ if should_calc:
973
+ self.previous_hidden_states = x.clone()
974
+ return not should_calc
975
+
976
+ def store(self, hidden_states):
977
+ if self.previous_hidden_states is None:
978
+ return
979
+ self.previous_residual = hidden_states - self.previous_hidden_states
980
+ self.previous_hidden_states = None
981
+
982
+ def update(self, hidden_states):
983
+ hidden_states = hidden_states + self.previous_residual
984
+ return hidden_states
humo/generate_1_7B.py ADDED
@@ -0,0 +1,622 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ # Inference codes adapted from [SeedVR]
13
+ # https://github.com/ByteDance-Seed/SeedVR/blob/main/projects/inference_seedvr2_7b.py
14
+
15
+ import math
16
+ import os
17
+ import gc
18
+ import random
19
+ import sys
20
+ import mediapy
21
+ import torch
22
+ import torch.distributed as dist
23
+ from omegaconf import DictConfig, ListConfig, OmegaConf
24
+ from einops import rearrange
25
+ from omegaconf import OmegaConf
26
+ from PIL import Image, ImageOps
27
+ from torchvision.transforms import ToTensor
28
+ from tqdm import tqdm
29
+ from torch.distributed.device_mesh import init_device_mesh
30
+ from torch.distributed.fsdp import (
31
+ BackwardPrefetch,
32
+ FullyShardedDataParallel,
33
+ MixedPrecision,
34
+ ShardingStrategy,
35
+ )
36
+ from common.distributed import (
37
+ get_device,
38
+ get_global_rank,
39
+ get_local_rank,
40
+ meta_param_init_fn,
41
+ meta_non_persistent_buffer_init_fn,
42
+ init_torch,
43
+ )
44
+ from common.distributed.advanced import (
45
+ init_unified_parallel,
46
+ get_unified_parallel_world_size,
47
+ get_sequence_parallel_rank,
48
+ init_model_shard_cpu_group,
49
+ )
50
+ from common.logger import get_logger
51
+ from common.config import create_object
52
+ from common.distributed import get_device, get_global_rank
53
+ from torchvision.transforms import Compose, Normalize, ToTensor
54
+ from humo.models.wan_modules.t5 import T5EncoderModel
55
+ from humo.models.wan_modules.vae import WanVAE
56
+ from humo.models.utils.utils import tensor_to_video, prepare_json_dataset
57
+ from contextlib import contextmanager
58
+ import torch.cuda.amp as amp
59
+ from humo.models.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
60
+ from humo.utils.audio_processor_whisper import AudioProcessor
61
+ from humo.utils.wav2vec import linear_interpolation_fps
62
+
63
+
64
+ image_transform = Compose([
65
+ ToTensor(),
66
+ Normalize(mean=0.5, std=0.5),
67
+ ])
68
+
69
+ SIZE_CONFIGS = {
70
+ '720*1280': (720, 1280),
71
+ '1280*720': (1280, 720),
72
+ '480*832': (480, 832),
73
+ '832*480': (832, 480),
74
+ '1024*1024': (1024, 1024),
75
+ }
76
+
77
+ def clever_format(nums, format="%.2f"):
78
+ from typing import Iterable
79
+ if not isinstance(nums, Iterable):
80
+ nums = [nums]
81
+ clever_nums = []
82
+ for num in nums:
83
+ if num > 1e12:
84
+ clever_nums.append(format % (num / 1e12) + "T")
85
+ elif num > 1e9:
86
+ clever_nums.append(format % (num / 1e9) + "G")
87
+ elif num > 1e6:
88
+ clever_nums.append(format % (num / 1e6) + "M")
89
+ elif num > 1e3:
90
+ clever_nums.append(format % (num / 1e3) + "K")
91
+ else:
92
+ clever_nums.append(format % num + "B")
93
+
94
+ clever_nums = clever_nums[0] if len(clever_nums) == 1 else (*clever_nums,)
95
+
96
+ return clever_nums
97
+
98
+
99
+ class Generator():
100
+ def __init__(self, config: DictConfig):
101
+ self.config = config.copy()
102
+ OmegaConf.set_readonly(self.config, True)
103
+ self.logger = get_logger(self.__class__.__name__)
104
+ self.configure_models()
105
+
106
+ # init_torch(cudnn_benchmark=False)
107
+
108
+ def get_fsdp_sharding_config(self, sharding_strategy, device_mesh_config):
109
+ device_mesh = None
110
+ fsdp_strategy = ShardingStrategy[sharding_strategy]
111
+ if (
112
+ fsdp_strategy in [ShardingStrategy._HYBRID_SHARD_ZERO2, ShardingStrategy.HYBRID_SHARD]
113
+ and device_mesh_config is not None
114
+ ):
115
+ device_mesh = init_device_mesh("cuda", tuple(device_mesh_config))
116
+ return device_mesh, fsdp_strategy
117
+
118
+ def configure_models(self):
119
+ self.configure_dit_model(device="cpu")
120
+ self.configure_vae_model()
121
+ if self.config.generation.get('extract_audio_feat', False):
122
+ self.configure_wav2vec(device="cpu")
123
+ self.configure_text_model(device="cpu")
124
+
125
+ # Initialize fsdp.
126
+ self.configure_dit_fsdp_model()
127
+ self.configure_text_fsdp_model()
128
+
129
+ def configure_dit_model(self, device=get_device()):
130
+
131
+ init_unified_parallel(self.config.dit.sp_size)
132
+ self.sp_size = get_unified_parallel_world_size()
133
+
134
+ # Create dit model.
135
+ init_device = "meta"
136
+ with torch.device(init_device):
137
+ self.dit = create_object(self.config.dit.model)
138
+ self.logger.info(f"Load DiT model on {init_device}.")
139
+ self.dit.eval().requires_grad_(False)
140
+
141
+ # Load dit checkpoint.
142
+ path = self.config.dit.checkpoint_dir
143
+ if path.endswith(".pth"):
144
+ state = torch.load(path, map_location=device, mmap=True)
145
+ missing_keys, unexpected_keys = self.dit.load_state_dict(state, strict=False, assign=True)
146
+ self.logger.info(
147
+ f"dit loaded from {path}. "
148
+ f"Missing keys: {len(missing_keys)}, "
149
+ f"Unexpected keys: {len(unexpected_keys)}"
150
+ )
151
+ else:
152
+ from safetensors.torch import load_file
153
+ import json
154
+ def load_custom_sharded_weights(model_dir, base_name, device=device):
155
+ index_path = f"{model_dir}/{base_name}.safetensors.index.json"
156
+ with open(index_path, "r") as f:
157
+ index = json.load(f)
158
+ weight_map = index["weight_map"]
159
+ shard_files = set(weight_map.values())
160
+ state_dict = {}
161
+ for shard_file in shard_files:
162
+ shard_path = f"{model_dir}/{shard_file}"
163
+ shard_state = load_file(shard_path)
164
+ shard_state = {k: v.to(device) for k, v in shard_state.items()}
165
+ state_dict.update(shard_state)
166
+ return state_dict
167
+ state = load_custom_sharded_weights(path, 'humo', device)
168
+ self.dit.load_state_dict(state, strict=False, assign=True)
169
+
170
+ self.dit = meta_non_persistent_buffer_init_fn(self.dit)
171
+ if device in [get_device(), "cuda"]:
172
+ self.dit.to(get_device())
173
+
174
+ # Print model size.
175
+ params = sum(p.numel() for p in self.dit.parameters())
176
+ self.logger.info(
177
+ f"[RANK:{get_global_rank()}] DiT Parameters: {clever_format(params, '%.3f')}"
178
+ )
179
+
180
+ def configure_vae_model(self, device=get_device()):
181
+ self.vae_stride = self.config.vae.vae_stride
182
+ self.vae = WanVAE(
183
+ vae_pth=self.config.vae.checkpoint,
184
+ device=device)
185
+
186
+ if self.config.generation.height == 480:
187
+ self.zero_vae = torch.load(self.config.dit.zero_vae_path)
188
+ elif self.config.generation.height == 720:
189
+ self.zero_vae = torch.load(self.config.dit.zero_vae_720p_path)
190
+ else:
191
+ raise ValueError(f"Unsupported height {self.config.generation.height} for zero-vae.")
192
+
193
+ def configure_wav2vec(self, device=get_device()):
194
+ audio_separator_model_file = self.config.audio.vocal_separator
195
+ wav2vec_model_path = self.config.audio.wav2vec_model
196
+
197
+ self.audio_processor = AudioProcessor(
198
+ 16000,
199
+ 25,
200
+ wav2vec_model_path,
201
+ "all",
202
+ audio_separator_model_file,
203
+ None, # do not separate vocals
204
+ os.path.join(self.config.generation.output.dir, "vocals"),
205
+ device=device,
206
+ )
207
+
208
+ def configure_text_model(self, device=get_device()):
209
+ self.text_encoder = T5EncoderModel(
210
+ text_len=self.config.dit.model.text_len,
211
+ dtype=torch.bfloat16,
212
+ device=device,
213
+ checkpoint_path=self.config.text.t5_checkpoint,
214
+ tokenizer_path=self.config.text.t5_tokenizer,
215
+ )
216
+
217
+
218
+ def configure_dit_fsdp_model(self):
219
+ self.dit.to(get_device())
220
+
221
+ return
222
+
223
+
224
+ def configure_text_fsdp_model(self):
225
+ self.text_encoder.to(get_device())
226
+
227
+ return
228
+
229
+
230
+ def load_image_latent_ref_id(self, path: str, size, device):
231
+ # Load size.
232
+ h, w = size[1], size[0]
233
+
234
+ # Load image.
235
+ if len(path) > 1 and not isinstance(path, str):
236
+ ref_vae_latents = []
237
+ for image_path in path:
238
+ with Image.open(image_path) as img:
239
+ img = img.convert("RGB")
240
+
241
+ # Calculate the required size to keep aspect ratio and fill the rest with padding.
242
+ img_ratio = img.width / img.height
243
+ target_ratio = w / h
244
+
245
+ if img_ratio > target_ratio: # Image is wider than target
246
+ new_width = w
247
+ new_height = int(new_width / img_ratio)
248
+ else: # Image is taller than target
249
+ new_height = h
250
+ new_width = int(new_height * img_ratio)
251
+
252
+ # img = img.resize((new_width, new_height), Image.ANTIALIAS)
253
+ img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
254
+
255
+ # Create a new image with the target size and place the resized image in the center
256
+ delta_w = w - img.size[0]
257
+ delta_h = h - img.size[1]
258
+ padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
259
+ new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
260
+
261
+ # Transform to tensor and normalize.
262
+ transform = Compose(
263
+ [
264
+ ToTensor(),
265
+ Normalize(0.5, 0.5),
266
+ ]
267
+ )
268
+ new_img = transform(new_img)
269
+ # img_vae_latent = self.vae_encode([new_img.unsqueeze(1)])[0]
270
+ img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device)
271
+ ref_vae_latents.append(img_vae_latent[0])
272
+
273
+ return [torch.cat(ref_vae_latents, dim=1)]
274
+ else:
275
+ if not isinstance(path, str):
276
+ path = path[0]
277
+ with Image.open(path) as img:
278
+ img = img.convert("RGB")
279
+
280
+ # Calculate the required size to keep aspect ratio and fill the rest with padding.
281
+ img_ratio = img.width / img.height
282
+ target_ratio = w / h
283
+
284
+ if img_ratio > target_ratio: # Image is wider than target
285
+ new_width = w
286
+ new_height = int(new_width / img_ratio)
287
+ else: # Image is taller than target
288
+ new_height = h
289
+ new_width = int(new_height * img_ratio)
290
+
291
+ # img = img.resize((new_width, new_height), Image.ANTIALIAS)
292
+ img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
293
+
294
+ # Create a new image with the target size and place the resized image in the center
295
+ delta_w = w - img.size[0]
296
+ delta_h = h - img.size[1]
297
+ padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
298
+ new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
299
+
300
+ # Transform to tensor and normalize.
301
+ transform = Compose(
302
+ [
303
+ ToTensor(),
304
+ Normalize(0.5, 0.5),
305
+ ]
306
+ )
307
+ new_img = transform(new_img)
308
+ img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device)
309
+
310
+ # Vae encode.
311
+ return img_vae_latent
312
+
313
+ def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2):
314
+ zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
315
+ zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) # device=audio_emb.device
316
+ iter_ = 1 + (frame_num - 1) // 4
317
+ audio_emb_wind = []
318
+ for lt_i in range(iter_):
319
+ if lt_i == 0:
320
+ st = frame0_idx + lt_i - 2
321
+ ed = frame0_idx + lt_i + 3
322
+ wind_feat = torch.stack([
323
+ audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
324
+ for i in range(st, ed)
325
+ ], dim=0)
326
+ wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0)
327
+ else:
328
+ st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
329
+ ed = frame0_idx + 1 + 4 * lt_i + audio_shift
330
+ wind_feat = torch.stack([
331
+ audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
332
+ for i in range(st, ed)
333
+ ], dim=0)
334
+ audio_emb_wind.append(wind_feat)
335
+ audio_emb_wind = torch.stack(audio_emb_wind, dim=0)
336
+
337
+ return audio_emb_wind, ed - audio_shift
338
+
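A minimal shape walkthrough of `get_audio_emb_window` (hypothetical sizes; `gen` stands for an initialized `Generator`, and the 5x1280 layout matches the zero-audio fallback used later in `inference`):

    emb = torch.zeros(97, 5, 1280)                       # [frames, blocks, channels]
    windows, next_idx = gen.get_audio_emb_window(emb, frame_num=97, frame0_idx=0)
    # windows.shape == (25, 8, 5, 1280): one 8-row audio window per latent frame,
    # since 1 + (97 - 1) // 4 == 25; out-of-range frame indices are filled with zeros.
    # next_idx == 97, i.e. the first frame index not consumed by this clip.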
339
+ def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"):
340
+ if wav_enc_type == "wav2vec":
341
+ feat_merge = audio_emb
342
+ elif wav_enc_type == "whisper":
343
+ feat0 = linear_interpolation_fps(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25)
344
+ feat1 = linear_interpolation_fps(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25)
345
+ feat2 = linear_interpolation_fps(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25)
346
+ feat3 = linear_interpolation_fps(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25)
347
+ feat4 = linear_interpolation_fps(audio_emb[:, :, 32], 50, 25)
348
+ feat_merge = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0]
349
+ else:
350
+ raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}")
351
+
352
+ return feat_merge
353
+
354
+ def forward_tia(self, latents, latents_ref, latents_ref_neg, timestep, arg_t, arg_ta, arg_null):
355
+ neg = self.dit(
356
+ [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_null
357
+ )[0]
358
+
359
+ pos_t = self.dit(
360
+ [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_t
361
+ )[0]
362
+ pos_ta = self.dit(
363
+ [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_ta
364
+ )[0]
365
+ pos_tia = self.dit(
366
+ [torch.cat([latent[:,:-latent_ref.shape[1]], latent_ref], dim=1) for latent, latent_ref in zip(latents, latents_ref)], t=timestep, **arg_ta
367
+ )[0]
368
+
369
+ noise_pred = self.config.generation.scale_i * (pos_tia - pos_ta) + \
370
+ self.config.generation.scale_a * (pos_ta - pos_t) + \
371
+ self.config.generation.scale_t * (pos_t - neg) + \
372
+ neg
373
+
374
+ return noise_pred
375
+
376
+ def forward_ta(self, latents, latents_ref_neg, timestep, arg_t, arg_ta, arg_null):
377
+ neg = self.dit(
378
+ [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_null
379
+ )[0]
380
+
381
+ pos_t = self.dit(
382
+ [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_t
383
+ )[0]
384
+ pos_ta = self.dit(
385
+ [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_ta
386
+ )[0]
387
+
388
+ noise_pred = self.config.generation.scale_a * (pos_ta - pos_t) + \
389
+ self.config.generation.scale_t * (pos_t - neg) + \
390
+ neg
391
+
392
+ return noise_pred
393
+
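The two forward helpers above implement a layered, classifier-free-guidance-style composition: starting from the unconditional prediction, text, audio, and identity (reference image) terms are added with independent scales. An equivalent standalone restatement of the combination in `forward_tia`:

    def compose_guidance(neg, pos_t, pos_ta, pos_tia, scale_t, scale_a, scale_i):
        out = neg
        out = out + scale_t * (pos_t - neg)         # text guidance
        out = out + scale_a * (pos_ta - pos_t)      # audio guidance, on top of text
        out = out + scale_i * (pos_tia - pos_ta)    # identity guidance, on top of text + audio
        return out

With `scale_t == scale_a == scale_i == 1.0` the sum telescopes to `pos_tia`, the fully conditioned prediction; `forward_ta` is the same composition without the identity term.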
394
+
395
+ @torch.no_grad()
396
+ def inference(self,
397
+ input_prompt,
398
+ img_path,
399
+ audio_path,
400
+ size=(1280, 720),
401
+ frame_num=81,
402
+ shift=5.0,
403
+ sample_solver='unipc',
404
+ sampling_steps=50,
405
+ n_prompt="",
406
+ seed=-1,
407
+ offload_model=True,
408
+ device = get_device(),
409
+ ):
410
+
411
+ self.vae.model.to(device=device)
412
+ if img_path is not None:
413
+ latents_ref = self.load_image_latent_ref_id(img_path, size, device)
414
+ else:
415
+ latents_ref = [torch.zeros(16, 1, size[1]//8, size[0]//8).to(device)]
416
+
417
+ self.vae.model.to(device="cpu")
418
+ latents_ref_neg = [torch.zeros_like(latent_ref) for latent_ref in latents_ref]
419
+
420
+ # audio
421
+ if audio_path is not None:
422
+ if self.config.generation.extract_audio_feat:
423
+ self.audio_processor.whisper.to(device=device)
424
+ audio_emb, audio_length = self.audio_processor.preprocess(audio_path)
425
+ self.audio_processor.whisper.to(device='cpu')
426
+ else:
427
+ audio_emb_path = audio_path.replace(".wav", ".pt")
428
+ audio_emb = torch.load(audio_emb_path).to(device=device)
429
+ audio_emb = self.audio_emb_enc(audio_emb, wav_enc_type="whisper")
430
+ self.logger.info("使用预先提取好的音频特征: %s", audio_emb_path)
431
+ else:
432
+ audio_emb = torch.zeros(frame_num, 5, 1280).to(device)
433
+
434
+ frame_num = frame_num if frame_num != -1 else audio_length
435
+ frame_num = 4 * ((frame_num - 1) // 4) + 1
436
+ audio_emb, _ = self.get_audio_emb_window(audio_emb, frame_num, frame0_idx=0)
437
+ zero_audio_pad = torch.zeros(latents_ref[0].shape[1], *audio_emb.shape[1:]).to(audio_emb.device)
438
+ audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0)
439
+ audio_emb = [audio_emb.to(device)]
440
+ audio_emb_neg = [torch.zeros_like(audio_emb[0])]
441
+
442
+ # preprocess
443
+ self.patch_size = self.config.dit.model.patch_size
444
+ F = frame_num
445
+ target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + latents_ref[0].shape[1],
446
+ size[1] // self.vae_stride[1],
447
+ size[0] // self.vae_stride[2])
448
+
449
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) /
450
+ (self.patch_size[1] * self.patch_size[2]) *
451
+ target_shape[1] / self.sp_size) * self.sp_size
452
+
453
+ if n_prompt == "":
454
+ n_prompt = self.config.generation.sample_neg_prompt
455
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
456
+ seed_g = torch.Generator(device=device)
457
+ seed_g.manual_seed(seed)
458
+
459
+ self.text_encoder.model.to(device)
460
+ context = self.text_encoder([input_prompt], device)
461
+ context_null = self.text_encoder([n_prompt], device)
462
+ self.text_encoder.model.cpu()
463
+
464
+ noise = [
465
+ torch.randn(
466
+ target_shape[0],
467
+ target_shape[1], # - latents_ref[0].shape[1],
468
+ target_shape[2],
469
+ target_shape[3],
470
+ dtype=torch.float32,
471
+ device=device,
472
+ generator=seed_g)
473
+ ]
474
+
475
+ @contextmanager
476
+ def noop_no_sync():
477
+ yield
478
+
479
+ no_sync = getattr(self.dit, 'no_sync', noop_no_sync)
480
+ # step_change = self.config.generation.step_change # 980
481
+
482
+ # evaluation mode
483
+ with amp.autocast(dtype=torch.bfloat16), torch.no_grad(), no_sync():
484
+
485
+ if sample_solver == 'unipc':
486
+ sample_scheduler = FlowUniPCMultistepScheduler(
487
+ num_train_timesteps=1000,
488
+ shift=1,
489
+ use_dynamic_shifting=False)
490
+ sample_scheduler.set_timesteps(
491
+ sampling_steps, device=device, shift=shift)
492
+ timesteps = sample_scheduler.timesteps
493
+
494
+ # sample videos
495
+ latents = noise
496
+
497
+ # The reference image is passed explicitly in the model inputs below, not through the arg dicts.
498
+ arg_ta = {'context': context, 'seq_len': seq_len, 'audio': audio_emb}
499
+ arg_t = {'context': context, 'seq_len': seq_len, 'audio': audio_emb_neg}
500
+ arg_null = {'context': context_null, 'seq_len': seq_len, 'audio': audio_emb_neg}
501
+
502
+ torch.cuda.empty_cache()
503
+ self.dit.to(device=get_device())
504
+ for _, t in enumerate(tqdm(timesteps)):
505
+ timestep = [t]
506
+ timestep = torch.stack(timestep)
507
+
508
+ if self.config.generation.mode == "TIA":
509
+ noise_pred = self.forward_tia(latents, latents_ref, latents_ref_neg, timestep, arg_t, arg_ta, arg_null)
510
+ elif self.config.generation.mode == "TA":
511
+ noise_pred = self.forward_ta(latents, latents_ref_neg, timestep, arg_t, arg_ta, arg_null)
512
+ else:
513
+ raise ValueError(f"Unsupported generation mode: {self.config.generation.mode}")
514
+
515
+ temp_x0 = sample_scheduler.step(
516
+ noise_pred.unsqueeze(0),
517
+ t,
518
+ latents[0].unsqueeze(0),
519
+ return_dict=False,
520
+ generator=seed_g)[0]
521
+ latents = [temp_x0.squeeze(0)]
522
+
523
+ del timestep
524
+ torch.cuda.empty_cache()
525
+
526
+ x0 = latents
527
+ x0 = [x0_[:,:-latents_ref[0].shape[1]] for x0_ in x0]
528
+
529
+ # if offload_model:
530
+ self.dit.cpu()
531
+ torch.cuda.empty_cache()
532
+ # if get_local_rank() == 0:
533
+ self.vae.model.to(device=device)
534
+ videos = self.vae.decode(x0)
535
+ self.vae.model.to(device="cpu")
536
+
537
+ del noise, latents, noise_pred
538
+ del audio_emb, audio_emb_neg, latents_ref, latents_ref_neg, context, context_null
539
+ del x0, temp_x0
540
+ del sample_scheduler
541
+ torch.cuda.empty_cache()
542
+ gc.collect()
543
+ torch.cuda.synchronize()
544
+ if dist.is_initialized():
545
+ dist.barrier()
546
+
547
+ return videos[0] # if get_local_rank() == 0 else None
548
+
549
+
550
+ def inference_loop(self, prompt, ref_img_path, audio_path, output_dir, filename, width = 832, height = 480, steps=50, frames = 97, seed = 0):
551
+ print(f'ref_img_path:{ref_img_path}')
552
+
553
+ video = self.inference(
554
+ prompt,
555
+ ref_img_path,
556
+ audio_path,
557
+ size=SIZE_CONFIGS[f"{width}*{height}"],
558
+ frame_num=frames,
559
+ shift=self.config.diffusion.timesteps.sampling.shift,
560
+ sample_solver='unipc',
561
+ sampling_steps=steps,
562
+ seed=seed,
563
+ offload_model=False,
564
+ )
565
+
566
+ torch.cuda.empty_cache()
567
+ gc.collect()
568
+
569
+
570
+ # Save samples.
571
+ if get_sequence_parallel_rank() == 0:
572
+ pathname = self.save_sample(
573
+ sample=video,
574
+ audio_path=audio_path,
575
+ output_dir = output_dir,
576
+ filename=filename,
577
+ )
578
+ self.logger.info(f"Finished {filename}, saved to {pathname}.")
579
+
580
+ del video, prompt
581
+ torch.cuda.empty_cache()
582
+ gc.collect()
583
+
584
+
585
+
586
+ def save_sample(self, *, sample: torch.Tensor, audio_path: str, output_dir: str, filename: str):
587
+ gen_config = self.config.generation
588
+ # Prepare file path.
589
+ extension = ".mp4" if sample.ndim == 4 else ".png"
590
+ filename += extension
591
+ pathname = os.path.join(output_dir, filename)
592
+ # Convert sample.
593
+ sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).to("cpu", torch.uint8)
594
+ sample = rearrange(sample, "c t h w -> t h w c")
595
+ # Save file.
596
+ if sample.ndim == 4:
597
+ if audio_path is not None:
598
+ tensor_to_video(
599
+ sample.numpy(),
600
+ pathname,
601
+ audio_path,
602
+ fps=gen_config.fps)
603
+ else:
604
+ mediapy.write_video(
605
+ path=pathname,
606
+ images=sample.numpy(),
607
+ fps=gen_config.fps,
608
+ )
609
+ else:
610
+ raise ValueError
611
+ return pathname
612
+
613
+
614
+ def prepare_positive_prompts(self):
615
+ pos_prompts = self.config.generation.positive_prompt
616
+ if pos_prompts.endswith(".json"):
617
+ pos_prompts = prepare_json_dataset(pos_prompts)
618
+ else:
619
+ raise NotImplementedError
620
+ assert isinstance(pos_prompts, ListConfig)
621
+
622
+ return pos_prompts
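A minimal driver sketch for the class above (config path, prompt, and media paths are placeholders; it assumes torch/distributed initialization has already been performed by the entry point):

    from omegaconf import OmegaConf

    config = OmegaConf.load("humo/configs/inference/generate.yaml")   # placeholder config path
    gen = Generator(config)
    gen.inference_loop(
        prompt="A person talking to the camera",
        ref_img_path=["examples/amber.png"],     # list of reference images (placeholder)
        audio_path="examples/art.wav",           # placeholder audio clip
        output_dir="./output",
        filename="demo",
        width=832, height=480, steps=50, frames=97, seed=0,
    )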
humo/models/audio/audio_proj.py ADDED
@@ -0,0 +1,87 @@
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import nn
4
+ from einops import rearrange
5
+
6
+ class WanRMSNorm(nn.Module):
7
+
8
+ def __init__(self, dim, eps=1e-5):
9
+ super().__init__()
10
+ self.dim = dim
11
+ self.eps = eps
12
+ self.weight = nn.Parameter(torch.ones(dim))
13
+
14
+ def forward(self, x):
15
+ r"""
16
+ Args:
17
+ x(Tensor): Shape [B, L, C]
18
+ """
19
+ return self._norm(x.float()).type_as(x) * self.weight
20
+
21
+ def _norm(self, x):
22
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
23
+
24
+
25
+ class DummyAdapterLayer(nn.Module):
26
+ def __init__(self, layer):
27
+ super().__init__()
28
+ self.layer = layer
29
+
30
+ def forward(self, *args, **kwargs):
31
+ return self.layer(*args, **kwargs)
32
+
33
+
34
+ class AudioProjModel(nn.Module):
35
+ def __init__(
36
+ self,
37
+ seq_len=5,
38
+ blocks=13, # add a new parameter blocks
39
+ channels=768, # add a new parameter channels
40
+ intermediate_dim=512,
41
+ output_dim=1536,
42
+ context_tokens=16,
43
+ ):
44
+ super().__init__()
45
+
46
+ self.seq_len = seq_len
47
+ self.blocks = blocks
48
+ self.channels = channels
49
+ self.input_dim = seq_len * blocks * channels # update input_dim to be the product of blocks and channels.
50
+ self.intermediate_dim = intermediate_dim
51
+ self.context_tokens = context_tokens
52
+ self.output_dim = output_dim
53
+
54
+ # define multiple linear layers
55
+ self.audio_proj_glob_1 = DummyAdapterLayer(nn.Linear(self.input_dim, intermediate_dim))
56
+ self.audio_proj_glob_2 = DummyAdapterLayer(nn.Linear(intermediate_dim, intermediate_dim))
57
+ self.audio_proj_glob_3 = DummyAdapterLayer(nn.Linear(intermediate_dim, context_tokens * output_dim))
58
+
59
+ self.audio_proj_glob_norm = DummyAdapterLayer(nn.LayerNorm(output_dim))
60
+
61
+ self.initialize_weights()
62
+
63
+ def initialize_weights(self):
64
+ # Initialize transformer layers:
65
+ def _basic_init(module):
66
+ if isinstance(module, nn.Linear):
67
+ torch.nn.init.xavier_uniform_(module.weight)
68
+ if module.bias is not None:
69
+ nn.init.constant_(module.bias, 0)
70
+
71
+ self.apply(_basic_init)
72
+
73
+ def forward(self, audio_embeds):
74
+ video_length = audio_embeds.shape[1]
75
+ audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
76
+ batch_size, window_size, blocks, channels = audio_embeds.shape
77
+ audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
78
+
79
+ audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds))
80
+ audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds))
81
+
82
+ context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim)
83
+
84
+ context_tokens = self.audio_proj_glob_norm(context_tokens)
85
+ context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
86
+
87
+ return context_tokens
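A minimal shape walkthrough for `AudioProjModel` (dimensions are illustrative and chosen to be self-consistent; the constructor defaults above differ):

    proj = AudioProjModel(seq_len=8, blocks=5, channels=1280,
                          intermediate_dim=512, output_dim=1536, context_tokens=16)
    audio = torch.randn(1, 25, 8, 5, 1280)       # [bz, frames, window, blocks, channels]
    tokens = proj(audio)
    # tokens.shape == (1, 25, 16, 1536): 16 context tokens of width 1536 per video frame,
    # produced by flattening each 8x5x1280 window and passing it through the MLP above.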
humo/models/distributed/__init__.py ADDED
File without changes
humo/models/distributed/dit_ulysses_sequence_parallel.py ADDED
@@ -0,0 +1,270 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import torch
13
+ from einops import rearrange
14
+ from common.distributed import get_device
15
+
16
+ from common.distributed.advanced import (
17
+ get_unified_parallel_world_size,
18
+ get_unified_parallel_group,
19
+ pad_tensor,
20
+ Slice,
21
+ gather_outputs,
22
+ gather_seq_scatter_heads_qkv,
23
+ gather_seq_scatter_double_head,
24
+ gather_heads_scatter_seq,
25
+ unpad_tensor
26
+ )
27
+ from humo.models.wan_modules.attention import flash_attention
28
+ from humo.models.wan_modules.model_humo import rope_apply, sinusoidal_embedding_1d
29
+
30
+
31
+ def ulysses_dit_forward(
32
+ self,
33
+ x,
34
+ t,
35
+ context,
36
+ seq_len,
37
+ audio=None,
38
+ y=None
39
+ ):
40
+ """
41
+ x: A list of videos each with shape [C, T, H, W].
42
+ t: [B].
43
+ context: A list of text embeddings each with shape [L, C].
44
+ """
45
+ if self.model_type == 'i2v':
46
+ # assert clip_fea is not None and y is not None
47
+ assert y is not None
48
+ # params
49
+ device = self.patch_embedding.weight.device
50
+ if self.freqs.device != device:
51
+ self.freqs = self.freqs.to(device)
52
+
53
+ if y is not None:
54
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
55
+
56
+ # embeddings
57
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
58
+ grid_sizes = torch.stack(
59
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
60
+ x = [u.flatten(2).transpose(1, 2) for u in x]
61
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long, device=device)
62
+
63
+ assert seq_lens.max() <= seq_len
64
+ x = torch.cat([
65
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
66
+ for u in x
67
+ ])
68
+
69
+ # time embeddings
70
+ with torch.amp.autocast('cuda', dtype=torch.float32):
71
+ e = self.time_embedding(
72
+ sinusoidal_embedding_1d(self.freq_dim, t).float()).float()
73
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim)).float()
74
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
75
+
76
+ # context
77
+ context_lens = None
78
+ context = self.text_embedding(
79
+ torch.stack([
80
+ torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
81
+ for u in context
82
+ ]))
83
+
84
+ if self.insert_audio:
85
+ audio = [self.audio_proj(au.unsqueeze(0)).permute(0, 3, 1, 2) for au in audio]
86
+
87
+ audio_seq_len = torch.tensor(max([au.shape[2] for au in audio]) * audio[0].shape[3], device=get_device())
88
+ audio = [au.flatten(2).transpose(1, 2) for au in audio] # [1, t*32, 1536]
89
+ audio_seq_lens = torch.tensor([au.size(1) for au in audio], dtype=torch.long, device=device)
90
+ audio = torch.cat([
91
+ torch.cat([au, au.new_zeros(1, audio_seq_len - au.size(1), au.size(2))],
92
+ dim=1) for au in audio
93
+ ])
94
+ else:
95
+ audio = None
96
+ audio_seq_len = None
97
+ audio_seq_lens = None
98
+
99
+ # ulysses support
100
+ sp_world = get_unified_parallel_world_size()
101
+ group = get_unified_parallel_group()
102
+ if seq_len % sp_world:
103
+ padding_size = sp_world - (seq_len % sp_world)
104
+ x = pad_tensor(x, dim=1, padding_size=padding_size)
105
+
106
+ if self.insert_audio:
107
+ audio_padding_size = sp_world - (audio_seq_len % sp_world)
108
+ audio = pad_tensor(audio, dim=1, padding_size=audio_padding_size)
109
+
110
+ x = Slice.apply(group, x, 1, True)
111
+
112
+ if self.insert_audio:
113
+ audio = Slice.apply(group, audio, 1, True)
114
+
115
+ # arguments
116
+ kwargs = dict(
117
+ e=e0,
118
+ seq_lens=seq_lens,
119
+ grid_sizes=grid_sizes,
120
+ freqs=self.freqs,
121
+ context=context,
122
+ context_lens=context_lens,
123
+ audio=audio,
124
+ audio_seq_len=audio_seq_len)
125
+
126
+ for block in self.blocks:
127
+ x = block(x, **kwargs)
128
+
129
+ # head
130
+ x = self.head(x, e)
131
+
132
+ # ulysses support
133
+ x = gather_outputs(x, gather_dim=1, padding_dim=1, unpad_dim_size=seq_len, scale_grad=True)
134
+
135
+ # unpatchify
136
+ x = self.unpatchify(x, grid_sizes)
137
+ return [u.float() for u in x]
138
+
139
+
140
+ def ulysses_attn_forward(
141
+ self,
142
+ x,
143
+ seq_lens,
144
+ grid_sizes,
145
+ freqs,
146
+ dtype=torch.bfloat16
147
+ ):
148
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
149
+ seq_len = seq_lens.max()
150
+ half_dtypes = (torch.float16, torch.bfloat16)
151
+
152
+ def half(x):
153
+ return x if x.dtype in half_dtypes else x.to(dtype)
154
+
155
+ # query, key, value function
156
+ def qkv_fn(x):
157
+ q = self.norm_q(self.q(x))
158
+ k = self.norm_k(self.k(x))
159
+ v = self.v(x)
160
+ return q, k, v
161
+
162
+ q, k, v = qkv_fn(x)
163
+
164
+ # ulysses support
165
+ sp_size = get_unified_parallel_world_size()
166
+ if n % sp_size:
167
+ pad_size = sp_size - (n % sp_size)
168
+ pad_size = pad_size * d
169
+ pad_inner_dim = n * d + pad_size
170
+ q = pad_tensor(q, dim=2, padding_size=pad_size)
171
+ k = pad_tensor(k, dim=2, padding_size=pad_size)
172
+ v = pad_tensor(v, dim=2, padding_size=pad_size)
173
+ else:
174
+ pad_inner_dim = n * d
175
+
176
+ qkv = torch.cat([q, k, v], dim=2)
177
+ qkv = gather_seq_scatter_heads_qkv(qkv, seq_dim=1, unpadded_dim_size=seq_len)
178
+ q, k, v = qkv.split(pad_inner_dim // sp_size, dim=2)
179
+
180
+ pad_n = pad_inner_dim // d
181
+ pad_split_n = pad_n // sp_size
182
+ q = q.view(b, seq_len, pad_split_n, d)
183
+ k = k.view(b, seq_len, pad_split_n, d)
184
+ v = v.view(b, seq_len, pad_split_n, d)
185
+
186
+ q = rope_apply(q, grid_sizes, freqs)
187
+ k = rope_apply(k, grid_sizes, freqs)
188
+
189
+ x = flash_attention(
190
+ q=half(q),
191
+ k=half(k),
192
+ v=half(v),
193
+ k_lens=seq_lens,
194
+ window_size=self.window_size
195
+ )
196
+
197
+ # ulysses support
198
+ x = x.flatten(2)
199
+ x = gather_heads_scatter_seq(x, head_dim=2, seq_dim=1)
200
+ if n % sp_size:
201
+ x = unpad_tensor(x, dim=2, unpad_dim_size=seq_len)
202
+
203
+ x = self.o(x)
204
+ return x
205
+
206
+
207
+ def ulysses_audio_cross_attn_forward(
208
+ self,
209
+ x,
210
+ audio,
211
+ seq_lens,
212
+ grid_sizes,
213
+ freqs,
214
+ audio_seq_len,
215
+ dtype=torch.bfloat16
216
+ ):
217
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
218
+ seq_len = seq_lens.max()
219
+
220
+ q = self.norm_q(self.q(x))
221
+ k = self.norm_k(self.k(audio))
222
+ v = self.v(audio)
223
+
224
+ # ulysses support
225
+ sp_size = get_unified_parallel_world_size()
226
+ if n % sp_size:
227
+ pad_size = sp_size - (n % sp_size)
228
+ pad_size = pad_size * d
229
+ pad_inner_dim = n * d + pad_size
230
+ q = pad_tensor(q, dim=2, padding_size=pad_size)
231
+ k = pad_tensor(k, dim=2, padding_size=pad_size)
232
+ v = pad_tensor(v, dim=2, padding_size=pad_size)
233
+ else:
234
+ pad_inner_dim = n * d
235
+
236
+ qq = torch.cat([q, q], dim=2)
237
+ kv = torch.cat([k, v], dim=2)
238
+ qq = gather_seq_scatter_double_head(qq, seq_dim=1, unpadded_dim_size=seq_len)
239
+ kv = gather_seq_scatter_double_head(kv, seq_dim=1, unpadded_dim_size=audio_seq_len)
240
+ q, _ = qq.split(pad_inner_dim // sp_size, dim=2)
241
+ k, v = kv.split(pad_inner_dim // sp_size, dim=2)
242
+
243
+ pad_n = pad_inner_dim // d
244
+ pad_split_n = pad_n // sp_size
245
+ q = q.view(b, seq_len, pad_split_n, d)
246
+ k = k.view(b, audio_seq_len, pad_split_n, d)
247
+ v = v.view(b, audio_seq_len, pad_split_n, d)
248
+
249
+ hlen_wlen = int(grid_sizes[0][1] * grid_sizes[0][2])
250
+ assert hlen_wlen == 1560 or hlen_wlen == 3600
251
+ q = q.reshape(-1, hlen_wlen, pad_split_n, d)
252
+ k = k.reshape(-1, 16, pad_split_n, d)
253
+ v = v.reshape(-1, 16, pad_split_n, d)
254
+
255
+ x = flash_attention(
256
+ q=q,
257
+ k=k,
258
+ v=v,
259
+ k_lens=None,
260
+ )
261
+ x = x.view(b, -1, pad_split_n, d)
262
+
263
+ # ulysses support
264
+ x = x.flatten(2)
265
+ x = gather_heads_scatter_seq(x, head_dim=2, seq_dim=1)
266
+ if n % sp_size:
267
+ x = unpad_tensor(x, dim=2, unpad_dim_size=seq_len)
268
+
269
+ x = self.o(x)
270
+ return x
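The three forwards above rely on padding the flattened sequence (and, where needed, the head dimension) to a multiple of the sequence-parallel world size before scattering it across ranks. A standalone sketch of just that pad-to-multiple step, without a process group (`sp_world` is illustrative):

    import torch
    import torch.nn.functional as F

    def pad_to_multiple(x: torch.Tensor, sp_world: int, dim: int = 1) -> torch.Tensor:
        """Right-pad `x` with zeros along `dim` so its size is divisible by sp_world."""
        remainder = x.size(dim) % sp_world
        if remainder == 0:
            return x
        # F.pad expects (left, right) pairs ordered from the last dimension backwards.
        pad = [0, 0] * (x.dim() - dim - 1) + [0, sp_world - remainder]
        return F.pad(x, pad)

    x = torch.randn(1, 75602, 5120)              # [batch, seq, dim], illustrative sizes
    assert pad_to_multiple(x, sp_world=8).size(1) % 8 == 0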
humo/models/distributed/fsdp.py ADDED
@@ -0,0 +1,42 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from functools import partial
13
+
14
+ import torch
15
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
16
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
17
+ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
18
+
19
+
20
+ def shard_model(
21
+ model,
22
+ device_id,
23
+ param_dtype=torch.bfloat16,
24
+ reduce_dtype=torch.float32,
25
+ buffer_dtype=torch.float32,
26
+ process_group=None,
27
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
28
+ sync_module_states=True,
29
+ ):
30
+ model = FSDP(
31
+ module=model,
32
+ process_group=process_group,
33
+ sharding_strategy=sharding_strategy,
34
+ auto_wrap_policy=partial(
35
+ lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
36
+ mixed_precision=MixedPrecision(
37
+ param_dtype=param_dtype,
38
+ reduce_dtype=reduce_dtype,
39
+ buffer_dtype=buffer_dtype),
40
+ device_id=device_id,
41
+ sync_module_states=sync_module_states)
42
+ return model
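A hypothetical usage sketch for `shard_model` (the `dit` module is a placeholder; it assumes `torch.distributed` is already initialized and that the module exposes a `.blocks` ModuleList, which the lambda wrap policy above keys on):

    sharded_dit = shard_model(
        model=dit,                                  # an already-constructed DiT-style module
        device_id=torch.cuda.current_device(),
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.float32,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        sync_module_states=True,
    )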
humo/models/text/encoder.py ADDED
@@ -0,0 +1,173 @@
1
+ import os
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Union
4
+ import torch
5
+ from omegaconf import DictConfig, OmegaConf
6
+ from torch import nn
7
+ from transformers import (
8
+ AutoModelForCausalLM,
9
+ AutoTokenizer,
10
+ CLIPTextModel,
11
+ CLIPTokenizerFast,
12
+ T5EncoderModel,
13
+ T5TokenizerFast,
14
+ )
15
+ from transformers.tokenization_utils_base import BatchEncoding
16
+
17
+ from common.fs import download_and_extract
18
+ from common.logger import get_logger
19
+
20
+ logger = get_logger(__name__)
21
+
22
+ MODEL_TYPES = {
23
+ "clip": (CLIPTokenizerFast, CLIPTextModel),
24
+ "t5": (T5TokenizerFast, T5EncoderModel),
25
+ "llm14b": (AutoTokenizer, AutoModelForCausalLM),
26
+ }
27
+
28
+
29
+ @dataclass
30
+ class TextEncoderOutput:
31
+ embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]]
32
+ masks: Union[torch.BoolTensor, List[torch.BoolTensor]]
33
+ pooled: Optional[Union[torch.FloatTensor, List[torch.FloatTensor]]]
34
+
35
+
36
+ class TextEncoder(nn.Module):
37
+ def __init__(self, config: DictConfig):
38
+ super().__init__()
39
+ self.config = config
40
+ self.tokenizers = []
41
+ self.models = nn.ModuleList([])
42
+
43
+ # Disable tokenizer parallelism since we already use distributed training.
44
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
45
+
46
+ for model in config.models:
47
+ tokenizer_cls, model_cls = MODEL_TYPES[model.type]
48
+ path = download_and_extract(model.path)
49
+ max_length = model.max_length
50
+
51
+ if model.type == "llm14b":
52
+ tokenizer = tokenizer_cls.from_pretrained(
53
+ path,
54
+ model_max_length=max_length,
55
+ use_fast=False,
56
+ trust_remote_code=True,
57
+ padding_side="right",
58
+ truncation_side="right",
59
+ add_eod_token=True,
60
+ )
61
+ tokenizer.add_special_tokens({"pad_token": "<|endoftext|>"})
62
+ model = model_cls.from_pretrained(path, trust_remote_code=True, bf16=True)
63
+ else:
64
+ tokenizer = tokenizer_cls.from_pretrained(path, model_max_length=max_length)
65
+ model = model_cls.from_pretrained(path, torch_dtype=torch.bfloat16)
66
+ self.tokenizers.append(tokenizer)
67
+ self.models.append(model)
68
+
69
+ def forward(self, text: Union[str, List[str]]) -> TextEncoderOutput:
70
+ embeddings, masks, pooled = [], [], []
71
+
72
+ for encoder_config, tokenizer, model in zip(
73
+ self.config.models, self.tokenizers, self.models
74
+ ):
75
+ if encoder_config.type == "llm14b":
76
+ use_mask = encoder_config.get("mask", True)
77
+ tokens = tokenizer(
78
+ text,
79
+ return_tensors="pt",
80
+ padding="max_length",
81
+ truncation=True,
82
+ ).to(model.device)
83
+ token_ids = tokens["input_ids"]
84
+ attention_mask = tokens["attention_mask"]
85
+ num_tokens = attention_mask.sum(dim=1)
86
+ range_ids = torch.arange(len(token_ids), device=token_ids.device, dtype=torch.long)
87
+ token_ids[range_ids, num_tokens.clamp(max=token_ids.size(1) - 1)] = (
88
+ tokenizer.pad_token_id
89
+ )
90
+ attention_mask[range_ids, num_tokens.clamp(max=token_ids.size(1) - 1)] = 1
91
+ tokens = BatchEncoding({"input_ids": token_ids, "attention_mask": attention_mask})
92
+ output = model.transformer(
93
+ input_ids=tokens.input_ids,
94
+ attention_mask=attention_mask if use_mask else None,
95
+ output_hidden_states=False,
96
+ use_cache=False,
97
+ )
98
+ emb = output.last_hidden_state # batch_size, num_tokens, feat_dim
99
+ # emb *= tokens.attention_mask.unsqueeze(-1)
100
+
101
+ embeddings.append(emb)
102
+ masks.append(
103
+ tokens.attention_mask.bool() if use_mask else tokens.attention_mask > -1
104
+ )
105
+
106
+ else:
107
+ # Tokenizer
108
+ tokens = tokenizer(
109
+ text=text,
110
+ truncation=True,
111
+ padding="max_length",
112
+ return_tensors="pt",
113
+ )
114
+
115
+ # Encoder
116
+ use_mask = encoder_config.get("mask", True)
117
+ input_ids = tokens.input_ids.to(model.device)
118
+ attention_mask = tokens.attention_mask.to(model.device)
119
+ output = model(
120
+ input_ids=input_ids,
121
+ attention_mask=attention_mask if use_mask else None,
122
+ output_hidden_states=True,
123
+ )
124
+
125
+ # Save embeddings from the defined layer.
126
+ layer = encoder_config.get("layer", "last")
127
+ if layer == "last":
128
+ embeddings.append(output.last_hidden_state)
129
+ elif layer == "penultimate":
130
+ embeddings.append(model.text_model.final_layer_norm(output.hidden_states[-2]))
131
+ elif layer == "penultimate_nonorm":
132
+ embeddings.append(output.hidden_states[-2])
133
+ else:
134
+ raise NotImplementedError(f"Unknown layer type: {layer}.")
135
+
136
+ # Save masks
137
+ masks.append(attention_mask.bool() if use_mask else attention_mask > -1)
138
+
139
+ # Save pooled output if available.
140
+ if hasattr(output, "pooler_output"):
141
+ pooled.append(output.pooler_output)
142
+
143
+ output_config = self.config.get("output") or OmegaConf.create()
144
+ embedding_output_type = output_config.get("embedding_and_mask", "undefined")
145
+ pooled_output_type = output_config.get("pooled", "undefined")
146
+
147
+ # Select or merge embeddings and mask if needed.
148
+ if embedding_output_type == "undefined" and len(self.models) == 1:
149
+ embeddings = embeddings[0]
150
+ masks = masks[0]
151
+ elif embedding_output_type == "channel_concat":
152
+ embeddings = torch.cat(embeddings, dim=-1)
153
+ masks = sum(masks).bool()
154
+ elif embedding_output_type == "last":
155
+ embeddings = embeddings[-1]
156
+ masks = masks[-1]
157
+ else:
158
+ raise NotImplementedError(f"output.embedding_and_mask: {embedding_output_type}")
159
+
160
+ # Select or merge pooled output if needed.
161
+ if pooled_output_type == "undefined":
162
+ pooled = None
163
+ elif pooled_output_type == "channel_concat":
164
+ pooled = torch.cat(pooled, dim=-1)
165
+ elif pooled_output_type == "first":
166
+ pooled = pooled[0]
167
+ elif pooled_output_type == "last":
168
+ pooled = pooled[-1]
169
+ else:
170
+ raise NotImplementedError(f"output.pooled: {pooled_output_type}")
171
+
172
+ # Return final results.
173
+ return TextEncoderOutput(embeddings, masks, pooled)
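A hypothetical config sketch for `TextEncoder` (the checkpoint path is a placeholder; field names follow what the class reads: `models[*].type/path/max_length/mask/layer` plus the optional `output` section). With a single model and no `output` section, the forward returns that model's embeddings and mask directly and `pooled` is `None`:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "models": [
            {"type": "t5", "path": "<t5-checkpoint-dir>", "max_length": 512, "layer": "last"},
        ],
    })
    encoder = TextEncoder(cfg)
    out = encoder(["a prompt"])      # TextEncoderOutput(embeddings, masks, pooled=None)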
humo/models/utils/fm_solvers.py ADDED
@@ -0,0 +1,857 @@
1
+ # Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
2
+ # Convert dpm solver for flow matching
3
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
4
+
5
+ import inspect
6
+ import math
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
13
+ SchedulerMixin,
14
+ SchedulerOutput)
15
+ from diffusers.utils import deprecate, is_scipy_available
16
+ from diffusers.utils.torch_utils import randn_tensor
17
+
18
+ if is_scipy_available():
19
+ pass
20
+
21
+
22
+ def get_sampling_sigmas(sampling_steps, shift):
23
+ sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
24
+ sigma = (shift * sigma / (1 + (shift - 1) * sigma))
25
+
26
+ return sigma
27
+
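A worked example of the shift applied in `get_sampling_sigmas`: with `shift=5.0`, a raw sigma of 0.5 becomes 5.0 * 0.5 / (1 + 4 * 0.5) = 2.5 / 3 ≈ 0.833, so the schedule spends more of its step budget at high-noise timesteps.

    sigmas = get_sampling_sigmas(sampling_steps=4, shift=5.0)
    # raw linspace [1.0, 0.75, 0.5, 0.25] -> shifted [1.0, 0.9375, 0.8333..., 0.625]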
28
+
29
+ def retrieve_timesteps(
30
+ scheduler,
31
+ num_inference_steps=None,
32
+ device=None,
33
+ timesteps=None,
34
+ sigmas=None,
35
+ **kwargs,
36
+ ):
37
+ if timesteps is not None and sigmas is not None:
38
+ raise ValueError(
39
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
40
+ )
41
+ if timesteps is not None:
42
+ accepts_timesteps = "timesteps" in set(
43
+ inspect.signature(scheduler.set_timesteps).parameters.keys())
44
+ if not accepts_timesteps:
45
+ raise ValueError(
46
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
47
+ f" timestep schedules. Please check whether you are using the correct scheduler."
48
+ )
49
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
50
+ timesteps = scheduler.timesteps
51
+ num_inference_steps = len(timesteps)
52
+ elif sigmas is not None:
53
+ accept_sigmas = "sigmas" in set(
54
+ inspect.signature(scheduler.set_timesteps).parameters.keys())
55
+ if not accept_sigmas:
56
+ raise ValueError(
57
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
58
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
59
+ )
60
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
61
+ timesteps = scheduler.timesteps
62
+ num_inference_steps = len(timesteps)
63
+ else:
64
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
65
+ timesteps = scheduler.timesteps
66
+ return timesteps, num_inference_steps
67
+
68
+
69
+ class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
70
+ """
71
+ `FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
72
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
73
+ methods the library implements for all schedulers such as loading and saving.
74
+ Args:
75
+ num_train_timesteps (`int`, defaults to 1000):
76
+ The number of diffusion steps to train the model. This determines the resolution of the diffusion process.
77
+ solver_order (`int`, defaults to 2):
78
+ The DPMSolver order which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided
79
+ sampling, and `solver_order=3` for unconditional sampling. This affects the number of model outputs stored
80
+ and used in multistep updates.
81
+ prediction_type (`str`, defaults to "flow_prediction"):
82
+ Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
83
+ the flow of the diffusion process.
84
+ shift (`float`, *optional*, defaults to 1.0):
85
+ A factor used to adjust the sigmas in the noise schedule. It modifies the step sizes during the sampling
86
+ process.
87
+ use_dynamic_shifting (`bool`, defaults to `False`):
88
+ Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is
89
+ applied on the fly.
90
+ thresholding (`bool`, defaults to `False`):
91
+ Whether to use the "dynamic thresholding" method. This method adjusts the predicted sample to prevent
92
+ saturation and improve photorealism.
93
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
94
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
95
+ sample_max_value (`float`, defaults to 1.0):
96
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
97
+ `algorithm_type="dpmsolver++"`.
98
+ algorithm_type (`str`, defaults to `dpmsolver++`):
99
+ Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
100
+ `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
101
+ paper, and the `dpmsolver++` type implements the algorithms in the
102
+ [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
103
+ `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
104
+ solver_type (`str`, defaults to `midpoint`):
105
+ Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
106
+ sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
107
+ lower_order_final (`bool`, defaults to `True`):
108
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
109
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
110
+ euler_at_final (`bool`, defaults to `False`):
111
+ Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
112
+ richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
113
+ steps, but sometimes may result in blurring.
114
+ final_sigmas_type (`str`, *optional*, defaults to "zero"):
115
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
116
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
117
+ lambda_min_clipped (`float`, defaults to `-inf`):
118
+ Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
119
+ cosine (`squaredcos_cap_v2`) noise schedule.
120
+ variance_type (`str`, *optional*):
121
+ Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
122
+ contains the predicted Gaussian variance.
123
+ """
124
+
125
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
126
+ order = 1
127
+
128
+ @register_to_config
129
+ def __init__(
130
+ self,
131
+ num_train_timesteps: int = 1000,
132
+ solver_order: int = 2,
133
+ prediction_type: str = "flow_prediction",
134
+ shift: Optional[float] = 1.0,
135
+ use_dynamic_shifting=False,
136
+ thresholding: bool = False,
137
+ dynamic_thresholding_ratio: float = 0.995,
138
+ sample_max_value: float = 1.0,
139
+ algorithm_type: str = "dpmsolver++",
140
+ solver_type: str = "midpoint",
141
+ lower_order_final: bool = True,
142
+ euler_at_final: bool = False,
143
+ final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
144
+ lambda_min_clipped: float = -float("inf"),
145
+ variance_type: Optional[str] = None,
146
+ invert_sigmas: bool = False,
147
+ ):
148
+ if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
149
+ deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
150
+ deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0",
151
+ deprecation_message)
152
+
153
+ # settings for DPM-Solver
154
+ if algorithm_type not in [
155
+ "dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"
156
+ ]:
157
+ if algorithm_type == "deis":
158
+ self.register_to_config(algorithm_type="dpmsolver++")
159
+ else:
160
+ raise NotImplementedError(
161
+ f"{algorithm_type} is not implemented for {self.__class__}")
162
+
163
+ if solver_type not in ["midpoint", "heun"]:
164
+ if solver_type in ["logrho", "bh1", "bh2"]:
165
+ self.register_to_config(solver_type="midpoint")
166
+ else:
167
+ raise NotImplementedError(
168
+ f"{solver_type} is not implemented for {self.__class__}")
169
+
170
+ if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"
171
+ ] and final_sigmas_type == "zero":
172
+ raise ValueError(
173
+ f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
174
+ )
175
+
176
+ # setable values
177
+ self.num_inference_steps = None
178
+ alphas = np.linspace(1, 1 / num_train_timesteps,
179
+ num_train_timesteps)[::-1].copy()
180
+ sigmas = 1.0 - alphas
181
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
182
+
183
+ if not use_dynamic_shifting:
184
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
185
+ sigmas = shift * sigmas / (1 +
186
+ (shift - 1) * sigmas) # pyright: ignore
187
+
188
+ self.sigmas = sigmas
189
+ self.timesteps = sigmas * num_train_timesteps
190
+
191
+ self.model_outputs = [None] * solver_order
192
+ self.lower_order_nums = 0
193
+ self._step_index = None
194
+ self._begin_index = None
195
+
196
+ # self.sigmas = self.sigmas.to(
197
+ # "cpu") # to avoid too much CPU/GPU communication
198
+ self.sigma_min = self.sigmas[-1].item()
199
+ self.sigma_max = self.sigmas[0].item()
200
+
201
+ @property
202
+ def step_index(self):
203
+ """
204
+ The index counter for current timestep. It will increase 1 after each scheduler step.
205
+ """
206
+ return self._step_index
207
+
208
+ @property
209
+ def begin_index(self):
210
+ """
211
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
212
+ """
213
+ return self._begin_index
214
+
215
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
216
+ def set_begin_index(self, begin_index: int = 0):
217
+ """
218
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
219
+ Args:
220
+ begin_index (`int`):
221
+ The begin index for the scheduler.
222
+ """
223
+ self._begin_index = begin_index
224
+
225
+ # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
226
+ def set_timesteps(
227
+ self,
228
+ num_inference_steps: Union[int, None] = None,
229
+ device: Union[str, torch.device] = None,
230
+ sigmas: Optional[List[float]] = None,
231
+ mu: Optional[Union[float, None]] = None,
232
+ shift: Optional[Union[float, None]] = None,
233
+ ):
234
+ """
235
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
236
+ Args:
237
+ num_inference_steps (`int`):
238
+ Total number of the spacing of the time steps.
239
+ device (`str` or `torch.device`, *optional*):
240
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
241
+ """
242
+
243
+ if self.config.use_dynamic_shifting and mu is None:
244
+ raise ValueError(
245
+ " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
246
+ )
247
+
248
+ if sigmas is None:
249
+ sigmas = np.linspace(self.sigma_max, self.sigma_min,
250
+ num_inference_steps +
251
+ 1).copy()[:-1] # pyright: ignore
252
+
253
+ if self.config.use_dynamic_shifting:
254
+ sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
255
+ else:
256
+ if shift is None:
257
+ shift = self.config.shift
258
+ sigmas = shift * sigmas / (1 +
259
+ (shift - 1) * sigmas) # pyright: ignore
260
+
261
+ if self.config.final_sigmas_type == "sigma_min":
262
+ sigma_last = ((1 - self.alphas_cumprod[0]) /
263
+ self.alphas_cumprod[0])**0.5
264
+ elif self.config.final_sigmas_type == "zero":
265
+ sigma_last = 0
266
+ else:
267
+ raise ValueError(
268
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
269
+ )
270
+
271
+ timesteps = sigmas * self.config.num_train_timesteps
272
+ sigmas = np.concatenate([sigmas, [sigma_last]
273
+ ]).astype(np.float32) # pyright: ignore
274
+
275
+ self.sigmas = torch.from_numpy(sigmas)
276
+ self.timesteps = torch.from_numpy(timesteps).to(
277
+ device=device, dtype=torch.int64)
278
+
279
+ self.num_inference_steps = len(timesteps)
280
+
281
+ self.model_outputs = [
282
+ None,
283
+ ] * self.config.solver_order
284
+ self.lower_order_nums = 0
285
+
286
+ self._step_index = None
287
+ self._begin_index = None
288
+ # self.sigmas = self.sigmas.to(
289
+ # "cpu") # to avoid too much CPU/GPU communication
290
+
291
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
292
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
293
+ """
294
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
295
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
296
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
297
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
298
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
299
+ https://arxiv.org/abs/2205.11487
300
+ """
301
+ dtype = sample.dtype
302
+ batch_size, channels, *remaining_dims = sample.shape
303
+
304
+ if dtype not in (torch.float32, torch.float64):
305
+ sample = sample.float(
306
+ ) # upcast for quantile calculation, and clamp not implemented for cpu half
307
+
308
+ # Flatten sample for doing quantile calculation along each image
309
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
310
+
311
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
312
+
313
+ s = torch.quantile(
314
+ abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
315
+ s = torch.clamp(
316
+ s, min=1, max=self.config.sample_max_value
317
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
318
+ s = s.unsqueeze(
319
+ 1) # (batch_size, 1) because clamp will broadcast along dim=0
320
+ sample = torch.clamp(
321
+ sample, -s, s
322
+ ) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
323
+
324
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
325
+ sample = sample.to(dtype)
326
+
327
+ return sample
328
+
329
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
330
+ def _sigma_to_t(self, sigma):
331
+ return sigma * self.config.num_train_timesteps
332
+
333
+ def _sigma_to_alpha_sigma_t(self, sigma):
334
+ return 1 - sigma, sigma
335
+
336
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
337
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
338
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
339
+
340
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
341
+ def convert_model_output(
342
+ self,
343
+ model_output: torch.Tensor,
344
+ *args,
345
+ sample: torch.Tensor = None,
346
+ **kwargs,
347
+ ) -> torch.Tensor:
348
+ """
349
+ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
350
+ designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
351
+ integral of the data prediction model.
352
+ <Tip>
353
+ The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
354
+ prediction and data prediction models.
355
+ </Tip>
356
+ Args:
357
+ model_output (`torch.Tensor`):
358
+ The direct output from the learned diffusion model.
359
+ sample (`torch.Tensor`):
360
+ A current instance of a sample created by the diffusion process.
361
+ Returns:
362
+ `torch.Tensor`:
363
+ The converted model output.
364
+ """
365
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
366
+ if sample is None:
367
+ if len(args) > 1:
368
+ sample = args[1]
369
+ else:
370
+ raise ValueError(
371
+ "missing `sample` as a required keyword argument")
372
+ if timestep is not None:
373
+ deprecate(
374
+ "timesteps",
375
+ "1.0.0",
376
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
377
+ )
378
+
379
+ # DPM-Solver++ needs to solve an integral of the data prediction model.
380
+ if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
381
+ if self.config.prediction_type == "flow_prediction":
382
+ sigma_t = self.sigmas[self.step_index]
383
+ x0_pred = sample - sigma_t * model_output
384
+ else:
385
+ raise ValueError(
386
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
387
+ " `v_prediction`, or `flow_prediction` for the FlowDPMSolverMultistepScheduler."
388
+ )
389
+
390
+ if self.config.thresholding:
391
+ x0_pred = self._threshold_sample(x0_pred)
392
+
393
+ return x0_pred
394
+
395
+ # DPM-Solver needs to solve an integral of the noise prediction model.
396
+ elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
397
+ if self.config.prediction_type == "flow_prediction":
398
+ sigma_t = self.sigmas[self.step_index]
399
+ epsilon = sample - (1 - sigma_t) * model_output
400
+ else:
401
+ raise ValueError(
402
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
403
+ " `v_prediction` or `flow_prediction` for the FlowDPMSolverMultistepScheduler."
404
+ )
405
+
406
+ if self.config.thresholding:
407
+ sigma_t = self.sigmas[self.step_index]
408
+ x0_pred = sample - sigma_t * model_output
409
+ x0_pred = self._threshold_sample(x0_pred)
410
+ epsilon = model_output + x0_pred
411
+
412
+ return epsilon
413
+
414
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update
415
+ def dpm_solver_first_order_update(
416
+ self,
417
+ model_output: torch.Tensor,
418
+ *args,
419
+ sample: torch.Tensor = None,
420
+ noise: Optional[torch.Tensor] = None,
421
+ **kwargs,
422
+ ) -> torch.Tensor:
423
+ """
424
+ One step for the first-order DPMSolver (equivalent to DDIM).
425
+ Args:
426
+ model_output (`torch.Tensor`):
427
+ The direct output from the learned diffusion model.
428
+ sample (`torch.Tensor`):
429
+ A current instance of a sample created by the diffusion process.
430
+ Returns:
431
+ `torch.Tensor`:
432
+ The sample tensor at the previous timestep.
433
+ """
434
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
435
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
436
+ "prev_timestep", None)
437
+ if sample is None:
438
+ if len(args) > 2:
439
+ sample = args[2]
440
+ else:
441
+ raise ValueError(
442
+ "missing `sample` as a required keyword argument")
443
+ if timestep is not None:
444
+ deprecate(
445
+ "timesteps",
446
+ "1.0.0",
447
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
448
+ )
449
+
450
+ if prev_timestep is not None:
451
+ deprecate(
452
+ "prev_timestep",
453
+ "1.0.0",
454
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
455
+ )
456
+
457
+ sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[
458
+ self.step_index] # pyright: ignore
459
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
460
+ alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
461
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
462
+ lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
463
+
464
+ h = lambda_t - lambda_s
465
+ if self.config.algorithm_type == "dpmsolver++":
466
+ x_t = (sigma_t /
467
+ sigma_s) * sample - (alpha_t *
468
+ (torch.exp(-h) - 1.0)) * model_output
469
+ elif self.config.algorithm_type == "dpmsolver":
470
+ x_t = (alpha_t /
471
+ alpha_s) * sample - (sigma_t *
472
+ (torch.exp(h) - 1.0)) * model_output
473
+ elif self.config.algorithm_type == "sde-dpmsolver++":
474
+ assert noise is not None
475
+ x_t = ((sigma_t / sigma_s * torch.exp(-h)) * sample +
476
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output +
477
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
478
+ elif self.config.algorithm_type == "sde-dpmsolver":
479
+ assert noise is not None
480
+ x_t = ((alpha_t / alpha_s) * sample - 2.0 *
481
+ (sigma_t * (torch.exp(h) - 1.0)) * model_output +
482
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
483
+ return x_t # pyright: ignore
484
+
485
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
486
+ def multistep_dpm_solver_second_order_update(
487
+ self,
488
+ model_output_list: List[torch.Tensor],
489
+ *args,
490
+ sample: torch.Tensor = None,
491
+ noise: Optional[torch.Tensor] = None,
492
+ **kwargs,
493
+ ) -> torch.Tensor:
494
+ """
495
+ One step for the second-order multistep DPMSolver.
496
+ Args:
497
+ model_output_list (`List[torch.Tensor]`):
498
+ The direct outputs from learned diffusion model at current and latter timesteps.
499
+ sample (`torch.Tensor`):
500
+ A current instance of a sample created by the diffusion process.
501
+ Returns:
502
+ `torch.Tensor`:
503
+ The sample tensor at the previous timestep.
504
+ """
505
+ timestep_list = args[0] if len(args) > 0 else kwargs.pop(
506
+ "timestep_list", None)
507
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
508
+ "prev_timestep", None)
509
+ if sample is None:
510
+ if len(args) > 2:
511
+ sample = args[2]
512
+ else:
513
+ raise ValueError(
514
+ "missing `sample` as a required keyword argument")
515
+ if timestep_list is not None:
516
+ deprecate(
517
+ "timestep_list",
518
+ "1.0.0",
519
+ "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
520
+ )
521
+
522
+ if prev_timestep is not None:
523
+ deprecate(
524
+ "prev_timestep",
525
+ "1.0.0",
526
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
527
+ )
528
+
529
+ sigma_t, sigma_s0, sigma_s1 = (
530
+ self.sigmas[self.step_index + 1], # pyright: ignore
531
+ self.sigmas[self.step_index],
532
+ self.sigmas[self.step_index - 1], # pyright: ignore
533
+ )
534
+
535
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
536
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
537
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
538
+
539
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
540
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
541
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
542
+
543
+ m0, m1 = model_output_list[-1], model_output_list[-2]
544
+
545
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
546
+ r0 = h_0 / h
547
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
548
+ if self.config.algorithm_type == "dpmsolver++":
549
+ # See https://arxiv.org/abs/2211.01095 for detailed derivations
550
+ if self.config.solver_type == "midpoint":
551
+ x_t = ((sigma_t / sigma_s0) * sample -
552
+ (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 *
553
+ (alpha_t * (torch.exp(-h) - 1.0)) * D1)
554
+ elif self.config.solver_type == "heun":
555
+ x_t = ((sigma_t / sigma_s0) * sample -
556
+ (alpha_t * (torch.exp(-h) - 1.0)) * D0 +
557
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1)
558
+ elif self.config.algorithm_type == "dpmsolver":
559
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
560
+ if self.config.solver_type == "midpoint":
561
+ x_t = ((alpha_t / alpha_s0) * sample -
562
+ (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 *
563
+ (sigma_t * (torch.exp(h) - 1.0)) * D1)
564
+ elif self.config.solver_type == "heun":
565
+ x_t = ((alpha_t / alpha_s0) * sample -
566
+ (sigma_t * (torch.exp(h) - 1.0)) * D0 -
567
+ (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1)
568
+ elif self.config.algorithm_type == "sde-dpmsolver++":
569
+ assert noise is not None
570
+ if self.config.solver_type == "midpoint":
571
+ x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample +
572
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 *
573
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 +
574
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
575
+ elif self.config.solver_type == "heun":
576
+ x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample +
577
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 +
578
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) /
579
+ (-2.0 * h) + 1.0)) * D1 +
580
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
581
+ elif self.config.algorithm_type == "sde-dpmsolver":
582
+ assert noise is not None
583
+ if self.config.solver_type == "midpoint":
584
+ x_t = ((alpha_t / alpha_s0) * sample - 2.0 *
585
+ (sigma_t * (torch.exp(h) - 1.0)) * D0 -
586
+ (sigma_t * (torch.exp(h) - 1.0)) * D1 +
587
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
588
+ elif self.config.solver_type == "heun":
589
+ x_t = ((alpha_t / alpha_s0) * sample - 2.0 *
590
+ (sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 *
591
+ (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 +
592
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
593
+ return x_t # pyright: ignore
594
+
595
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
596
+ def multistep_dpm_solver_third_order_update(
597
+ self,
598
+ model_output_list: List[torch.Tensor],
599
+ *args,
600
+ sample: torch.Tensor = None,
601
+ **kwargs,
602
+ ) -> torch.Tensor:
603
+ """
604
+ One step for the third-order multistep DPMSolver.
605
+ Args:
606
+ model_output_list (`List[torch.Tensor]`):
607
+ The direct outputs from learned diffusion model at current and latter timesteps.
608
+ sample (`torch.Tensor`):
609
+ A current instance of a sample created by diffusion process.
610
+ Returns:
611
+ `torch.Tensor`:
612
+ The sample tensor at the previous timestep.
613
+ """
614
+
615
+ timestep_list = args[0] if len(args) > 0 else kwargs.pop(
616
+ "timestep_list", None)
617
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
618
+ "prev_timestep", None)
619
+ if sample is None:
620
+ if len(args) > 2:
621
+ sample = args[2]
622
+ else:
623
+ raise ValueError(
624
+ "missing `sample` as a required keyword argument")
625
+ if timestep_list is not None:
626
+ deprecate(
627
+ "timestep_list",
628
+ "1.0.0",
629
+ "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
630
+ )
631
+
632
+ if prev_timestep is not None:
633
+ deprecate(
634
+ "prev_timestep",
635
+ "1.0.0",
636
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
637
+ )
638
+
639
+ sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
640
+ self.sigmas[self.step_index + 1], # pyright: ignore
641
+ self.sigmas[self.step_index],
642
+ self.sigmas[self.step_index - 1], # pyright: ignore
643
+ self.sigmas[self.step_index - 2], # pyright: ignore
644
+ )
645
+
646
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
647
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
648
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
649
+ alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
650
+
651
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
652
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
653
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
654
+ lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
655
+
656
+ m0, m1, m2 = model_output_list[-1], model_output_list[
657
+ -2], model_output_list[-3]
658
+
659
+ h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
660
+ r0, r1 = h_0 / h, h_1 / h
661
+ D0 = m0
662
+ D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
663
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
664
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
665
+ if self.config.algorithm_type == "dpmsolver++":
666
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
667
+ x_t = ((sigma_t / sigma_s0) * sample -
668
+ (alpha_t * (torch.exp(-h) - 1.0)) * D0 +
669
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 -
670
+ (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2)
671
+ elif self.config.algorithm_type == "dpmsolver":
672
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
673
+ x_t = ((alpha_t / alpha_s0) * sample - (sigma_t *
674
+ (torch.exp(h) - 1.0)) * D0 -
675
+ (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 -
676
+ (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2)
677
+ return x_t # pyright: ignore
678
+
679
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
680
+ if schedule_timesteps is None:
681
+ schedule_timesteps = self.timesteps
682
+
683
+ indices = (schedule_timesteps == timestep).nonzero()
684
+
685
+ # The sigma index that is taken for the **very** first `step`
686
+ # is always the second index (or the last index if there is only 1)
687
+ # This way we can ensure we don't accidentally skip a sigma in
688
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
689
+ pos = 1 if len(indices) > 1 else 0
690
+
691
+ return indices[pos].item()
692
+
693
+ def _init_step_index(self, timestep):
694
+ """
695
+ Initialize the step_index counter for the scheduler.
696
+ """
697
+
698
+ if self.begin_index is None:
699
+ if isinstance(timestep, torch.Tensor):
700
+ timestep = timestep.to(self.timesteps.device)
701
+ self._step_index = self.index_for_timestep(timestep)
702
+ else:
703
+ self._step_index = self._begin_index
704
+
705
+ # Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
706
+ def step(
707
+ self,
708
+ model_output: torch.Tensor,
709
+ timestep: Union[int, torch.Tensor],
710
+ sample: torch.Tensor,
711
+ generator=None,
712
+ variance_noise: Optional[torch.Tensor] = None,
713
+ return_dict: bool = True,
714
+ ) -> Union[SchedulerOutput, Tuple]:
715
+ """
716
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
717
+ the multistep DPMSolver.
718
+ Args:
719
+ model_output (`torch.Tensor`):
720
+ The direct output from learned diffusion model.
721
+ timestep (`int`):
722
+ The current discrete timestep in the diffusion chain.
723
+ sample (`torch.Tensor`):
724
+ A current instance of a sample created by the diffusion process.
725
+ generator (`torch.Generator`, *optional*):
726
+ A random number generator.
727
+ variance_noise (`torch.Tensor`):
728
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
729
+ itself. Useful for methods such as [`LEdits++`].
730
+ return_dict (`bool`):
731
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
732
+ Returns:
733
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
734
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
735
+ tuple is returned where the first element is the sample tensor.
736
+ """
737
+ if self.num_inference_steps is None:
738
+ raise ValueError(
739
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
740
+ )
741
+
742
+ if self.step_index is None:
743
+ self._init_step_index(timestep)
744
+
745
+ # Improve numerical stability for small number of steps
746
+ lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
747
+ self.config.euler_at_final or
748
+ (self.config.lower_order_final and len(self.timesteps) < 15) or
749
+ self.config.final_sigmas_type == "zero")
750
+ lower_order_second = ((self.step_index == len(self.timesteps) - 2) and
751
+ self.config.lower_order_final and
752
+ len(self.timesteps) < 15)
753
+
754
+ model_output = self.convert_model_output(model_output, sample=sample)
755
+ for i in range(self.config.solver_order - 1):
756
+ self.model_outputs[i] = self.model_outputs[i + 1]
757
+ self.model_outputs[-1] = model_output
758
+
759
+ # Upcast to avoid precision issues when computing prev_sample
760
+ sample = sample.to(torch.float32)
761
+ if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"
762
+ ] and variance_noise is None:
763
+ noise = randn_tensor(
764
+ model_output.shape,
765
+ generator=generator,
766
+ device=model_output.device,
767
+ dtype=torch.float32)
768
+ elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
769
+ noise = variance_noise.to(
770
+ device=model_output.device,
771
+ dtype=torch.float32) # pyright: ignore
772
+ else:
773
+ noise = None
774
+
775
+ if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
776
+ prev_sample = self.dpm_solver_first_order_update(
777
+ model_output, sample=sample, noise=noise)
778
+ elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
779
+ prev_sample = self.multistep_dpm_solver_second_order_update(
780
+ self.model_outputs, sample=sample, noise=noise)
781
+ else:
782
+ prev_sample = self.multistep_dpm_solver_third_order_update(
783
+ self.model_outputs, sample=sample)
784
+
785
+ if self.lower_order_nums < self.config.solver_order:
786
+ self.lower_order_nums += 1
787
+
788
+ # Cast sample back to expected dtype
789
+ prev_sample = prev_sample.to(model_output.dtype)
790
+
791
+ # upon completion increase step index by one
792
+ self._step_index += 1 # pyright: ignore
793
+
794
+ if not return_dict:
795
+ return (prev_sample,)
796
+
797
+ return SchedulerOutput(prev_sample=prev_sample)
798
+
799
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input
800
+ def scale_model_input(self, sample: torch.Tensor, *args,
801
+ **kwargs) -> torch.Tensor:
802
+ """
803
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
804
+ current timestep.
805
+ Args:
806
+ sample (`torch.Tensor`):
807
+ The input sample.
808
+ Returns:
809
+ `torch.Tensor`:
810
+ A scaled input sample.
811
+ """
812
+ return sample
813
+
814
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
815
+ def add_noise(
816
+ self,
817
+ original_samples: torch.Tensor,
818
+ noise: torch.Tensor,
819
+ timesteps: torch.IntTensor,
820
+ ) -> torch.Tensor:
821
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
822
+ sigmas = self.sigmas.to(
823
+ device=original_samples.device, dtype=original_samples.dtype)
824
+ if original_samples.device.type == "mps" and torch.is_floating_point(
825
+ timesteps):
826
+ # mps does not support float64
827
+ schedule_timesteps = self.timesteps.to(
828
+ original_samples.device, dtype=torch.float32)
829
+ timesteps = timesteps.to(
830
+ original_samples.device, dtype=torch.float32)
831
+ else:
832
+ schedule_timesteps = self.timesteps.to(original_samples.device)
833
+ timesteps = timesteps.to(original_samples.device)
834
+
835
+ # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
836
+ if self.begin_index is None:
837
+ step_indices = [
838
+ self.index_for_timestep(t, schedule_timesteps)
839
+ for t in timesteps
840
+ ]
841
+ elif self.step_index is not None:
842
+ # add_noise is called after first denoising step (for inpainting)
843
+ step_indices = [self.step_index] * timesteps.shape[0]
844
+ else:
845
+ # add noise is called before first denoising step to create initial latent(img2img)
846
+ step_indices = [self.begin_index] * timesteps.shape[0]
847
+
848
+ sigma = sigmas[step_indices].flatten()
849
+ while len(sigma.shape) < len(original_samples.shape):
850
+ sigma = sigma.unsqueeze(-1)
851
+
852
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
853
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
854
+ return noisy_samples
855
+
856
+ def __len__(self):
857
+ return self.config.num_train_timesteps
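Usage note: the scheduler defined in this file is driven by calling `set_timesteps` once and then `step` once per denoising iteration. A minimal sketch follows, assuming the constructor and `set_timesteps` arguments mirror the UniPC variant in the next file; `denoise` and the latent shape are hypothetical placeholders, not part of the repository.

import torch
from humo.models.utils.fm_solvers import FlowDPMSolverMultistepScheduler

def denoise(x, t):
    # dummy stand-in for the HuMo video model (hypothetical)
    return torch.zeros_like(x)

scheduler = FlowDPMSolverMultistepScheduler(num_train_timesteps=1000, shift=5.0)
scheduler.set_timesteps(num_inference_steps=50, device="cpu")

latents = torch.randn(1, 16, 21, 60, 104)  # illustrative latent shape
for t in scheduler.timesteps:
    velocity = denoise(latents, t)                               # flow/velocity prediction
    latents = scheduler.step(velocity, t, latents).prev_sample   # one DPM-Solver++ update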
humo/models/utils/fm_solvers_unipc.py ADDED
@@ -0,0 +1,800 @@
1
+ # Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
2
+ # Convert unipc for flow matching
3
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
4
+
5
+ import math
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
12
+ SchedulerMixin,
13
+ SchedulerOutput)
14
+ from diffusers.utils import deprecate, is_scipy_available
15
+
16
+ if is_scipy_available():
17
+ import scipy.stats
18
+
19
+
20
+ class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
21
+ """
22
+ `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
23
+
24
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
25
+ methods the library implements for all schedulers such as loading and saving.
26
+
27
+ Args:
28
+ num_train_timesteps (`int`, defaults to 1000):
29
+ The number of diffusion steps to train the model.
30
+ solver_order (`int`, default `2`):
31
+ The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
32
+ due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
33
+ unconditional sampling.
34
+ prediction_type (`str`, defaults to "flow_prediction"):
35
+ Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
36
+ the flow of the diffusion process.
37
+ thresholding (`bool`, defaults to `False`):
38
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
39
+ as Stable Diffusion.
40
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
41
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
42
+ sample_max_value (`float`, defaults to 1.0):
43
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
44
+ predict_x0 (`bool`, defaults to `True`):
45
+ Whether to use the updating algorithm on the predicted x0.
46
+ solver_type (`str`, default `bh2`):
47
+ Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
48
+ otherwise.
49
+ lower_order_final (`bool`, default `True`):
50
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
51
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
52
+ disable_corrector (`list`, default `[]`):
53
+ Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
54
+ and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
55
+ usually disabled during the first few steps.
56
+ solver_p (`SchedulerMixin`, default `None`):
57
+ Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`.
58
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
59
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
60
+ the sigmas are determined according to a sequence of noise levels {σi}.
61
+ use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
62
+ Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
63
+ timestep_spacing (`str`, defaults to `"linspace"`):
64
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
65
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
66
+ steps_offset (`int`, defaults to 0):
67
+ An offset added to the inference steps, as required by some model families.
68
+ final_sigmas_type (`str`, defaults to `"zero"`):
69
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
70
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
71
+ """
72
+
73
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
74
+ order = 1
75
+
76
+ @register_to_config
77
+ def __init__(
78
+ self,
79
+ num_train_timesteps: int = 1000,
80
+ solver_order: int = 2,
81
+ prediction_type: str = "flow_prediction",
82
+ shift: Optional[float] = 1.0,
83
+ use_dynamic_shifting=False,
84
+ thresholding: bool = False,
85
+ dynamic_thresholding_ratio: float = 0.995,
86
+ sample_max_value: float = 1.0,
87
+ predict_x0: bool = True,
88
+ solver_type: str = "bh2",
89
+ lower_order_final: bool = True,
90
+ disable_corrector: List[int] = [],
91
+ solver_p: SchedulerMixin = None,
92
+ timestep_spacing: str = "linspace",
93
+ steps_offset: int = 0,
94
+ final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
95
+ ):
96
+
97
+ if solver_type not in ["bh1", "bh2"]:
98
+ if solver_type in ["midpoint", "heun", "logrho"]:
99
+ self.register_to_config(solver_type="bh2")
100
+ else:
101
+ raise NotImplementedError(
102
+ f"{solver_type} is not implemented for {self.__class__}")
103
+
104
+ self.predict_x0 = predict_x0
105
+ # setable values
106
+ self.num_inference_steps = None
107
+ alphas = np.linspace(1, 1 / num_train_timesteps,
108
+ num_train_timesteps)[::-1].copy()
109
+ sigmas = 1.0 - alphas
110
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
111
+
112
+ if not use_dynamic_shifting:
113
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
114
+ sigmas = shift * sigmas / (1 +
115
+ (shift - 1) * sigmas) # pyright: ignore
116
+
117
+ self.sigmas = sigmas
118
+ self.timesteps = sigmas * num_train_timesteps
119
+
120
+ self.model_outputs = [None] * solver_order
121
+ self.timestep_list = [None] * solver_order
122
+ self.lower_order_nums = 0
123
+ self.disable_corrector = disable_corrector
124
+ self.solver_p = solver_p
125
+ self.last_sample = None
126
+ self._step_index = None
127
+ self._begin_index = None
128
+
129
+ self.sigmas = self.sigmas.to(
130
+ "cpu") # to avoid too much CPU/GPU communication
131
+ self.sigma_min = self.sigmas[-1].item()
132
+ self.sigma_max = self.sigmas[0].item()
133
+
134
+ @property
135
+ def step_index(self):
136
+ """
137
+ The index counter for current timestep. It will increase 1 after each scheduler step.
138
+ """
139
+ return self._step_index
140
+
141
+ @property
142
+ def begin_index(self):
143
+ """
144
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
145
+ """
146
+ return self._begin_index
147
+
148
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
149
+ def set_begin_index(self, begin_index: int = 0):
150
+ """
151
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
152
+
153
+ Args:
154
+ begin_index (`int`):
155
+ The begin index for the scheduler.
156
+ """
157
+ self._begin_index = begin_index
158
+
159
+ # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
160
+ def set_timesteps(
161
+ self,
162
+ num_inference_steps: Union[int, None] = None,
163
+ device: Union[str, torch.device] = None,
164
+ sigmas: Optional[List[float]] = None,
165
+ mu: Optional[Union[float, None]] = None,
166
+ shift: Optional[Union[float, None]] = None,
167
+ ):
168
+ """
169
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
170
+ Args:
171
+ num_inference_steps (`int`):
172
+ The number of diffusion steps used when generating samples with a pre-trained model.
173
+ device (`str` or `torch.device`, *optional*):
174
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
175
+ """
176
+
177
+ if self.config.use_dynamic_shifting and mu is None:
178
+ raise ValueError(
179
+ " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
180
+ )
181
+
182
+ if sigmas is None:
183
+ sigmas = np.linspace(self.sigma_max, self.sigma_min,
184
+ num_inference_steps +
185
+ 1).copy()[:-1] # pyright: ignore
186
+
187
+ if self.config.use_dynamic_shifting:
188
+ sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
189
+ else:
190
+ if shift is None:
191
+ shift = self.config.shift
192
+ sigmas = shift * sigmas / (1 +
193
+ (shift - 1) * sigmas) # pyright: ignore
194
+
195
+ if self.config.final_sigmas_type == "sigma_min":
196
+ sigma_last = ((1 - self.alphas_cumprod[0]) /
197
+ self.alphas_cumprod[0])**0.5
198
+ elif self.config.final_sigmas_type == "zero":
199
+ sigma_last = 0
200
+ else:
201
+ raise ValueError(
202
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
203
+ )
204
+
205
+ timesteps = sigmas * self.config.num_train_timesteps
206
+ sigmas = np.concatenate([sigmas, [sigma_last]
207
+ ]).astype(np.float32) # pyright: ignore
208
+
209
+ self.sigmas = torch.from_numpy(sigmas)
210
+ self.timesteps = torch.from_numpy(timesteps).to(
211
+ device=device, dtype=torch.int64)
212
+
213
+ self.num_inference_steps = len(timesteps)
214
+
215
+ self.model_outputs = [
216
+ None,
217
+ ] * self.config.solver_order
218
+ self.lower_order_nums = 0
219
+ self.last_sample = None
220
+ if self.solver_p:
221
+ self.solver_p.set_timesteps(self.num_inference_steps, device=device)
222
+
223
+ # add an index counter for schedulers that allow duplicated timesteps
224
+ self._step_index = None
225
+ self._begin_index = None
226
+ self.sigmas = self.sigmas.to(
227
+ "cpu") # to avoid too much CPU/GPU communication
228
+
229
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
230
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
231
+ """
232
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
233
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
234
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
235
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
236
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
237
+
238
+ https://arxiv.org/abs/2205.11487
239
+ """
240
+ dtype = sample.dtype
241
+ batch_size, channels, *remaining_dims = sample.shape
242
+
243
+ if dtype not in (torch.float32, torch.float64):
244
+ sample = sample.float(
245
+ ) # upcast for quantile calculation, and clamp not implemented for cpu half
246
+
247
+ # Flatten sample for doing quantile calculation along each image
248
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
249
+
250
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
251
+
252
+ s = torch.quantile(
253
+ abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
254
+ s = torch.clamp(
255
+ s, min=1, max=self.config.sample_max_value
256
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
257
+ s = s.unsqueeze(
258
+ 1) # (batch_size, 1) because clamp will broadcast along dim=0
259
+ sample = torch.clamp(
260
+ sample, -s, s
261
+ ) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
262
+
263
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
264
+ sample = sample.to(dtype)
265
+
266
+ return sample
267
+
268
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
269
+ def _sigma_to_t(self, sigma):
270
+ return sigma * self.config.num_train_timesteps
271
+
272
+ def _sigma_to_alpha_sigma_t(self, sigma):
273
+ return 1 - sigma, sigma
274
+
275
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
276
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
277
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
278
+
279
+ def convert_model_output(
280
+ self,
281
+ model_output: torch.Tensor,
282
+ *args,
283
+ sample: torch.Tensor = None,
284
+ **kwargs,
285
+ ) -> torch.Tensor:
286
+ r"""
287
+ Convert the model output to the corresponding type the UniPC algorithm needs.
288
+
289
+ Args:
290
+ model_output (`torch.Tensor`):
291
+ The direct output from the learned diffusion model.
292
+ timestep (`int`):
293
+ The current discrete timestep in the diffusion chain.
294
+ sample (`torch.Tensor`):
295
+ A current instance of a sample created by the diffusion process.
296
+
297
+ Returns:
298
+ `torch.Tensor`:
299
+ The converted model output.
300
+ """
301
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
302
+ if sample is None:
303
+ if len(args) > 1:
304
+ sample = args[1]
305
+ else:
306
+ raise ValueError(
307
+ "missing `sample` as a required keyword argument")
308
+ if timestep is not None:
309
+ deprecate(
310
+ "timesteps",
311
+ "1.0.0",
312
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
313
+ )
314
+
315
+ sigma = self.sigmas[self.step_index]
316
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
317
+
318
+ if self.predict_x0:
319
+ if self.config.prediction_type == "flow_prediction":
320
+ sigma_t = self.sigmas[self.step_index]
321
+ x0_pred = sample - sigma_t * model_output
322
+ else:
323
+ raise ValueError(
324
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
325
+ " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
326
+ )
327
+
328
+ if self.config.thresholding:
329
+ x0_pred = self._threshold_sample(x0_pred)
330
+
331
+ return x0_pred
332
+ else:
333
+ if self.config.prediction_type == "flow_prediction":
334
+ sigma_t = self.sigmas[self.step_index]
335
+ epsilon = sample - (1 - sigma_t) * model_output
336
+ else:
337
+ raise ValueError(
338
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
339
+ " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
340
+ )
341
+
342
+ if self.config.thresholding:
343
+ sigma_t = self.sigmas[self.step_index]
344
+ x0_pred = sample - sigma_t * model_output
345
+ x0_pred = self._threshold_sample(x0_pred)
346
+ epsilon = model_output + x0_pred
347
+
348
+ return epsilon
349
+
350
+ def multistep_uni_p_bh_update(
351
+ self,
352
+ model_output: torch.Tensor,
353
+ *args,
354
+ sample: torch.Tensor = None,
355
+ order: int = None, # pyright: ignore
356
+ **kwargs,
357
+ ) -> torch.Tensor:
358
+ """
359
+ One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
360
+
361
+ Args:
362
+ model_output (`torch.Tensor`):
363
+ The direct output from the learned diffusion model at the current timestep.
364
+ prev_timestep (`int`):
365
+ The previous discrete timestep in the diffusion chain.
366
+ sample (`torch.Tensor`):
367
+ A current instance of a sample created by the diffusion process.
368
+ order (`int`):
369
+ The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
370
+
371
+ Returns:
372
+ `torch.Tensor`:
373
+ The sample tensor at the previous timestep.
374
+ """
375
+ prev_timestep = args[0] if len(args) > 0 else kwargs.pop(
376
+ "prev_timestep", None)
377
+ if sample is None:
378
+ if len(args) > 1:
379
+ sample = args[1]
380
+ else:
381
+ raise ValueError(
382
+ "missing `sample` as a required keyword argument")
383
+ if order is None:
384
+ if len(args) > 2:
385
+ order = args[2]
386
+ else:
387
+ raise ValueError(
388
+ "missing `order` as a required keyword argument")
389
+ if prev_timestep is not None:
390
+ deprecate(
391
+ "prev_timestep",
392
+ "1.0.0",
393
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
394
+ )
395
+ model_output_list = self.model_outputs
396
+
397
+ s0 = self.timestep_list[-1]
398
+ m0 = model_output_list[-1]
399
+ x = sample
400
+
401
+ if self.solver_p:
402
+ x_t = self.solver_p.step(model_output, s0, x).prev_sample
403
+ return x_t
404
+
405
+ sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[
406
+ self.step_index] # pyright: ignore
407
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
408
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
409
+
410
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
411
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
412
+
413
+ h = lambda_t - lambda_s0
414
+ device = sample.device
415
+
416
+ rks = []
417
+ D1s = []
418
+ for i in range(1, order):
419
+ si = self.step_index - i # pyright: ignore
420
+ mi = model_output_list[-(i + 1)]
421
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
422
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
423
+ rk = (lambda_si - lambda_s0) / h
424
+ rks.append(rk)
425
+ D1s.append((mi - m0) / rk) # pyright: ignore
426
+
427
+ rks.append(1.0)
428
+ rks = torch.tensor(rks, device=device)
429
+
430
+ R = []
431
+ b = []
432
+
433
+ hh = -h if self.predict_x0 else h
434
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
435
+ h_phi_k = h_phi_1 / hh - 1
436
+
437
+ factorial_i = 1
438
+
439
+ if self.config.solver_type == "bh1":
440
+ B_h = hh
441
+ elif self.config.solver_type == "bh2":
442
+ B_h = torch.expm1(hh)
443
+ else:
444
+ raise NotImplementedError()
445
+
446
+ for i in range(1, order + 1):
447
+ R.append(torch.pow(rks, i - 1))
448
+ b.append(h_phi_k * factorial_i / B_h)
449
+ factorial_i *= i + 1
450
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
451
+
452
+ R = torch.stack(R)
453
+ b = torch.tensor(b, device=device)
454
+
455
+ if len(D1s) > 0:
456
+ D1s = torch.stack(D1s, dim=1) # (B, K)
457
+ # for order 2, we use a simplified version
458
+ if order == 2:
459
+ rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
460
+ else:
461
+ rhos_p = torch.linalg.solve(R[:-1, :-1],
462
+ b[:-1]).to(device).to(x.dtype)
463
+ else:
464
+ D1s = None
465
+
466
+ if self.predict_x0:
467
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
468
+ if D1s is not None:
469
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
470
+ D1s) # pyright: ignore
471
+ else:
472
+ pred_res = 0
473
+ x_t = x_t_ - alpha_t * B_h * pred_res
474
+ else:
475
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
476
+ if D1s is not None:
477
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
478
+ D1s) # pyright: ignore
479
+ else:
480
+ pred_res = 0
481
+ x_t = x_t_ - sigma_t * B_h * pred_res
482
+
483
+ x_t = x_t.to(x.dtype)
484
+ return x_t
485
+
486
+ def multistep_uni_c_bh_update(
487
+ self,
488
+ this_model_output: torch.Tensor,
489
+ *args,
490
+ last_sample: torch.Tensor = None,
491
+ this_sample: torch.Tensor = None,
492
+ order: int = None, # pyright: ignore
493
+ **kwargs,
494
+ ) -> torch.Tensor:
495
+ """
496
+ One step for the UniC (B(h) version).
497
+
498
+ Args:
499
+ this_model_output (`torch.Tensor`):
500
+ The model outputs at `x_t`.
501
+ this_timestep (`int`):
502
+ The current timestep `t`.
503
+ last_sample (`torch.Tensor`):
504
+ The generated sample before the last predictor `x_{t-1}`.
505
+ this_sample (`torch.Tensor`):
506
+ The generated sample after the last predictor `x_{t}`.
507
+ order (`int`):
508
+ The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
509
+
510
+ Returns:
511
+ `torch.Tensor`:
512
+ The corrected sample tensor at the current timestep.
513
+ """
514
+ this_timestep = args[0] if len(args) > 0 else kwargs.pop(
515
+ "this_timestep", None)
516
+ if last_sample is None:
517
+ if len(args) > 1:
518
+ last_sample = args[1]
519
+ else:
520
+ raise ValueError(
521
+ "missing `last_sample` as a required keyword argument")
522
+ if this_sample is None:
523
+ if len(args) > 2:
524
+ this_sample = args[2]
525
+ else:
526
+ raise ValueError(
527
+ "missing `this_sample` as a required keyword argument")
528
+ if order is None:
529
+ if len(args) > 3:
530
+ order = args[3]
531
+ else:
532
+ raise ValueError(
533
+ "missing `order` as a required keyword argument")
534
+ if this_timestep is not None:
535
+ deprecate(
536
+ "this_timestep",
537
+ "1.0.0",
538
+ "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
539
+ )
540
+
541
+ model_output_list = self.model_outputs
542
+
543
+ m0 = model_output_list[-1]
544
+ x = last_sample
545
+ x_t = this_sample
546
+ model_t = this_model_output
547
+
548
+ sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[
549
+ self.step_index - 1] # pyright: ignore
550
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
551
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
552
+
553
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
554
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
555
+
556
+ h = lambda_t - lambda_s0
557
+ device = this_sample.device
558
+
559
+ rks = []
560
+ D1s = []
561
+ for i in range(1, order):
562
+ si = self.step_index - (i + 1) # pyright: ignore
563
+ mi = model_output_list[-(i + 1)]
564
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
565
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
566
+ rk = (lambda_si - lambda_s0) / h
567
+ rks.append(rk)
568
+ D1s.append((mi - m0) / rk) # pyright: ignore
569
+
570
+ rks.append(1.0)
571
+ rks = torch.tensor(rks, device=device)
572
+
573
+ R = []
574
+ b = []
575
+
576
+ hh = -h if self.predict_x0 else h
577
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
578
+ h_phi_k = h_phi_1 / hh - 1
579
+
580
+ factorial_i = 1
581
+
582
+ if self.config.solver_type == "bh1":
583
+ B_h = hh
584
+ elif self.config.solver_type == "bh2":
585
+ B_h = torch.expm1(hh)
586
+ else:
587
+ raise NotImplementedError()
588
+
589
+ for i in range(1, order + 1):
590
+ R.append(torch.pow(rks, i - 1))
591
+ b.append(h_phi_k * factorial_i / B_h)
592
+ factorial_i *= i + 1
593
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
594
+
595
+ R = torch.stack(R)
596
+ b = torch.tensor(b, device=device)
597
+
598
+ if len(D1s) > 0:
599
+ D1s = torch.stack(D1s, dim=1)
600
+ else:
601
+ D1s = None
602
+
603
+ # for order 1, we use a simplified version
604
+ if order == 1:
605
+ rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
606
+ else:
607
+ rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
608
+
609
+ if self.predict_x0:
610
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
611
+ if D1s is not None:
612
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
613
+ else:
614
+ corr_res = 0
615
+ D1_t = model_t - m0
616
+ x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
617
+ else:
618
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
619
+ if D1s is not None:
620
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
621
+ else:
622
+ corr_res = 0
623
+ D1_t = model_t - m0
624
+ x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
625
+ x_t = x_t.to(x.dtype)
626
+ return x_t
627
+
628
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
629
+ if schedule_timesteps is None:
630
+ schedule_timesteps = self.timesteps
631
+
632
+ indices = (schedule_timesteps == timestep).nonzero()
633
+
634
+ # The sigma index that is taken for the **very** first `step`
635
+ # is always the second index (or the last index if there is only 1)
636
+ # This way we can ensure we don't accidentally skip a sigma in
637
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
638
+ pos = 1 if len(indices) > 1 else 0
639
+
640
+ return indices[pos].item()
641
+
642
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
643
+ def _init_step_index(self, timestep):
644
+ """
645
+ Initialize the step_index counter for the scheduler.
646
+ """
647
+
648
+ if self.begin_index is None:
649
+ if isinstance(timestep, torch.Tensor):
650
+ timestep = timestep.to(self.timesteps.device)
651
+ self._step_index = self.index_for_timestep(timestep)
652
+ else:
653
+ self._step_index = self._begin_index
654
+
655
+ def step(self,
656
+ model_output: torch.Tensor,
657
+ timestep: Union[int, torch.Tensor],
658
+ sample: torch.Tensor,
659
+ return_dict: bool = True,
660
+ generator=None) -> Union[SchedulerOutput, Tuple]:
661
+ """
662
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
663
+ the multistep UniPC.
664
+
665
+ Args:
666
+ model_output (`torch.Tensor`):
667
+ The direct output from learned diffusion model.
668
+ timestep (`int`):
669
+ The current discrete timestep in the diffusion chain.
670
+ sample (`torch.Tensor`):
671
+ A current instance of a sample created by the diffusion process.
672
+ return_dict (`bool`):
673
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
674
+
675
+ Returns:
676
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
677
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
678
+ tuple is returned where the first element is the sample tensor.
679
+
680
+ """
681
+ if self.num_inference_steps is None:
682
+ raise ValueError(
683
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
684
+ )
685
+
686
+ if self.step_index is None:
687
+ self._init_step_index(timestep)
688
+
689
+ use_corrector = (
690
+ self.step_index > 0 and
691
+ self.step_index - 1 not in self.disable_corrector and
692
+ self.last_sample is not None # pyright: ignore
693
+ )
694
+
695
+ model_output_convert = self.convert_model_output(
696
+ model_output, sample=sample)
697
+ if use_corrector:
698
+ sample = self.multistep_uni_c_bh_update(
699
+ this_model_output=model_output_convert,
700
+ last_sample=self.last_sample,
701
+ this_sample=sample,
702
+ order=self.this_order,
703
+ )
704
+
705
+ for i in range(self.config.solver_order - 1):
706
+ self.model_outputs[i] = self.model_outputs[i + 1]
707
+ self.timestep_list[i] = self.timestep_list[i + 1]
708
+
709
+ self.model_outputs[-1] = model_output_convert
710
+ self.timestep_list[-1] = timestep # pyright: ignore
711
+
712
+ if self.config.lower_order_final:
713
+ this_order = min(self.config.solver_order,
714
+ len(self.timesteps) -
715
+ self.step_index) # pyright: ignore
716
+ else:
717
+ this_order = self.config.solver_order
718
+
719
+ self.this_order = min(this_order,
720
+ self.lower_order_nums + 1) # warmup for multistep
721
+ assert self.this_order > 0
722
+
723
+ self.last_sample = sample
724
+ prev_sample = self.multistep_uni_p_bh_update(
725
+ model_output=model_output, # pass the original non-converted model output, in case solver-p is used
726
+ sample=sample,
727
+ order=self.this_order,
728
+ )
729
+
730
+ if self.lower_order_nums < self.config.solver_order:
731
+ self.lower_order_nums += 1
732
+
733
+ # upon completion increase step index by one
734
+ self._step_index += 1 # pyright: ignore
735
+
736
+ if not return_dict:
737
+ return (prev_sample,)
738
+
739
+ return SchedulerOutput(prev_sample=prev_sample)
740
+
741
+ def scale_model_input(self, sample: torch.Tensor, *args,
742
+ **kwargs) -> torch.Tensor:
743
+ """
744
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
745
+ current timestep.
746
+
747
+ Args:
748
+ sample (`torch.Tensor`):
749
+ The input sample.
750
+
751
+ Returns:
752
+ `torch.Tensor`:
753
+ A scaled input sample.
754
+ """
755
+ return sample
756
+
757
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
758
+ def add_noise(
759
+ self,
760
+ original_samples: torch.Tensor,
761
+ noise: torch.Tensor,
762
+ timesteps: torch.IntTensor,
763
+ ) -> torch.Tensor:
764
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
765
+ sigmas = self.sigmas.to(
766
+ device=original_samples.device, dtype=original_samples.dtype)
767
+ if original_samples.device.type == "mps" and torch.is_floating_point(
768
+ timesteps):
769
+ # mps does not support float64
770
+ schedule_timesteps = self.timesteps.to(
771
+ original_samples.device, dtype=torch.float32)
772
+ timesteps = timesteps.to(
773
+ original_samples.device, dtype=torch.float32)
774
+ else:
775
+ schedule_timesteps = self.timesteps.to(original_samples.device)
776
+ timesteps = timesteps.to(original_samples.device)
777
+
778
+ # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
779
+ if self.begin_index is None:
780
+ step_indices = [
781
+ self.index_for_timestep(t, schedule_timesteps)
782
+ for t in timesteps
783
+ ]
784
+ elif self.step_index is not None:
785
+ # add_noise is called after first denoising step (for inpainting)
786
+ step_indices = [self.step_index] * timesteps.shape[0]
787
+ else:
788
+ # add noise is called before first denoising step to create initial latent(img2img)
789
+ step_indices = [self.begin_index] * timesteps.shape[0]
790
+
791
+ sigma = sigmas[step_indices].flatten()
792
+ while len(sigma.shape) < len(original_samples.shape):
793
+ sigma = sigma.unsqueeze(-1)
794
+
795
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
796
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
797
+ return noisy_samples
798
+
799
+ def __len__(self):
800
+ return self.config.num_train_timesteps
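Usage note: the UniPC scheduler is used the same way; inside `step` it first runs the UniC corrector against the previous sample (when enabled) and then the UniP predictor. A minimal sketch, with `denoise` again a hypothetical stand-in:

import torch
from humo.models.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler

def denoise(x, t):
    # dummy stand-in for the real denoiser (hypothetical)
    return torch.zeros_like(x)

scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=1000, solver_order=2)
scheduler.set_timesteps(num_inference_steps=50, device="cpu", shift=5.0)

latents = torch.randn(1, 16, 21, 60, 104)  # illustrative latent shape
for t in scheduler.timesteps:
    flow = denoise(latents, t)
    latents = scheduler.step(flow, t, latents).prev_sample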
humo/models/utils/utils.py ADDED
@@ -0,0 +1,58 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import argparse
3
+ import binascii
4
+ import os
5
+ import os.path as osp
6
+ import json
7
+ from omegaconf import OmegaConf
8
+
9
+ import imageio
10
+ import torch
11
+ import torchvision
12
+ from moviepy.editor import AudioFileClip, VideoClip
13
+
14
+ __all__ = ['tensor_to_video', 'prepare_json_dataset']
15
+
16
+
17
+ def tensor_to_video(tensor, output_video_path, input_audio_path, fps=25):
18
+ """
19
+ Converts a tensor with shape [f, h, w, c] into a video and adds an audio track from the specified audio file.
20
+
21
+ Args:
22
+ tensor (numpy): The Tensor to be converted, shaped [f, h, w, c].
23
+ output_video_path (str): The file path where the output video will be saved.
24
+ input_audio_path (str): The path to the audio file (WAV file) that contains the audio track to be added.
25
+ fps (int): The frame rate of the output video. Default is 25 fps.
26
+ """
27
+ def make_frame(t):
28
+ frame_index = min(int(t * fps), tensor.shape[0] - 1)
29
+ return tensor[frame_index]
30
+
31
+ video_duration = tensor.shape[0] / fps
32
+ audio_clip = AudioFileClip(input_audio_path)
33
+ audio_duration = audio_clip.duration
34
+ final_duration = min(video_duration, audio_duration)
35
+ audio_clip = audio_clip.subclip(0, final_duration)
36
+ new_video_clip = VideoClip(make_frame, duration=final_duration)
37
+ new_video_clip = new_video_clip.set_audio(audio_clip)
38
+ new_video_clip.write_videofile(output_video_path, fps=fps, audio_codec="aac")
39
+
40
+
41
+ def prepare_json_dataset(json_path):
42
+ samples = []
43
+ with open(json_path, "rb") as f:
44
+ data = json.load(f)
45
+ for itemname, row in data.items():
46
+ text = row['prompt'].strip().replace("_", " ").strip('"')
47
+ audio_path = row['audio_path']
48
+ ref_img_path = [x for x in row['img_paths']]
49
+
50
+ samples.append({
51
+ "text": text,
52
+ "ref_img": ref_img_path,
53
+ "audio": audio_path,
54
+ "itemname": itemname
55
+ })
56
+ samples = OmegaConf.create(samples)
57
+
58
+ return samples
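Usage note: a small sketch of calling the two helpers above. The frame array, file names and JSON path are illustrative placeholders; the JSON is expected to map item names to entries with `prompt`, `audio_path` and `img_paths`, as read by `prepare_json_dataset`.

import numpy as np
from humo.models.utils.utils import tensor_to_video, prepare_json_dataset

# [f, h, w, c] uint8 RGB frames, as described in the docstring above
frames = (np.random.rand(50, 480, 832, 3) * 255).astype(np.uint8)
tensor_to_video(frames, "output.mp4", "input_audio.wav", fps=25)  # paths are placeholders

samples = prepare_json_dataset("test_cases.json")  # hypothetical path
print(samples[0].text, samples[0].audio, samples[0].itemname)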
humo/models/wan_modules/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ from .attention import flash_attention
2
+ from .model import WanModel
3
+ from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
4
+ from .tokenizers import HuggingfaceTokenizer
5
+ from .vae import WanVAE
6
+
7
+ __all__ = [
8
+ 'WanVAE',
9
+ 'WanModel',
10
+ 'T5Model',
11
+ 'T5Encoder',
12
+ 'T5Decoder',
13
+ 'T5EncoderModel',
14
+ 'HuggingfaceTokenizer',
15
+ 'flash_attention',
16
+ ]
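Usage note: these re-exports let the rest of the codebase pull the Wan building blocks from one place. A minimal sketch of the attention entry point; the tensor layout [batch, seq_len, num_heads, head_dim] matches the custom op in the next file, while the exact sizes and the CUDA device are illustrative assumptions.

import torch
from humo.models.wan_modules import flash_attention

q = torch.randn(1, 1024, 12, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 1024, 12, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 1024, 12, 128, device="cuda", dtype=torch.bfloat16)
out = flash_attention(q, k, v)  # dispatches to FA3/FA2 when available, else SDPA (see next file)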
humo/models/wan_modules/attention.py ADDED
@@ -0,0 +1,256 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import warnings
3
+ import torch
4
+ from typing import Optional, Tuple
5
+
6
+ try:
7
+ import flash_attn_interface
8
+ FLASH_ATTN_3_AVAILABLE = True
9
+ except ModuleNotFoundError:
10
+ FLASH_ATTN_3_AVAILABLE = False
11
+
12
+ try:
13
+ import flash_attn
14
+ FLASH_ATTN_2_AVAILABLE = True
15
+ except ModuleNotFoundError:
16
+ FLASH_ATTN_2_AVAILABLE = False
17
+
18
+
19
+ __all__ = [
20
+ 'flash_attention',
21
+ 'attention',
22
+ ]
23
+
24
+
25
+ # ---------------------------
26
+ # Custom op + fake kernel
27
+ # ---------------------------
28
+ from typing import Optional, Sequence
29
+
30
+
31
+
32
+
33
+ @torch.library.custom_op("wan::flash_attention", mutates_args=())
34
+ def _wan_flash_attention_op(
35
+ q: torch.Tensor,
36
+ k: torch.Tensor,
37
+ v: torch.Tensor,
38
+ q_lens: Optional[torch.Tensor] = None,
39
+ k_lens: Optional[torch.Tensor] = None,
40
+ dropout_p: float = 0.0,
41
+ softmax_scale: Optional[float] = None,
42
+ q_scale: Optional[float] = None,
43
+ causal: bool = False,
44
+ # IMPORTANT: schema-friendly default (None), not a tuple
45
+ window_size: Optional[Sequence[int]] = None,
46
+ deterministic: bool = False,
47
+ dtype: torch.dtype = torch.bfloat16,
48
+ version: Optional[int] = None,
49
+ ) -> torch.Tensor:
50
+ half_dtypes = (torch.float16, torch.bfloat16)
51
+ assert dtype in half_dtypes
52
+ assert q.size(-1) <= 256
53
+
54
+ # normalize window_size to a 2-tuple for FA2 API
55
+ if window_size is None:
56
+ ws = (-1, -1)
57
+ else:
58
+ ws = tuple(window_size)
59
+ if len(ws) != 2:
60
+ raise ValueError(f"window_size must have length 2; got {window_size!r}")
61
+
62
+ b, lq, nheads = q.shape[0], q.shape[1], q.shape[2]
63
+ lk = k.shape[1]
64
+ out_dtype = q.dtype
65
+
66
+ def half(x: torch.Tensor) -> torch.Tensor:
67
+ return x if x.dtype in half_dtypes else x.to(dtype)
68
+
69
+ # --- preprocess (unchanged) ---
70
+ if q_lens is None:
71
+ q_flat = half(q.flatten(0, 1))
72
+ q_lens = torch.tensor([lq] * b, dtype=torch.int32)
73
+ else:
74
+ q_flat = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
75
+
76
+ if k_lens is None:
77
+ k_flat = half(k.flatten(0, 1))
78
+ v_flat = half(v.flatten(0, 1))
79
+ k_lens = torch.tensor([lk] * b, dtype=torch.int32)
80
+ else:
81
+ k_flat = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
82
+ v_flat = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
83
+
84
+ q_flat = q_flat.to(v_flat.dtype); k_flat = k_flat.to(v_flat.dtype)
85
+ if q_scale is not None:
86
+ q_flat = q_flat * q_scale
87
+
88
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
89
+ warnings.warn('Flash attention 3 is not available, use flash attention 2 instead.')
90
+
91
+ if FLASH_ATTN_3_AVAILABLE:
92
+ ret = flash_attn_interface.flash_attn_varlen_func(
93
+ q=q_flat,
94
+ k=k_flat,
95
+ v=v_flat,
96
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q_flat.device, non_blocking=True),
97
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(k_flat.device, non_blocking=True),
98
+ seqused_q=None,
99
+ seqused_k=None,
100
+ max_seqlen_q=lq,
101
+ max_seqlen_k=lk,
102
+ softmax_scale=softmax_scale,
103
+ causal=causal,
104
+ deterministic=deterministic,
105
+ )
106
+ out0 = ret[0] if isinstance(ret, (tuple, list)) else ret
107
+ total_q = b * lq
108
+ if out0.dim() != 3:
109
+ raise RuntimeError(f"Unexpected FA3 output rank {out0.dim()} shape={tuple(out0.shape)}")
110
+ if out0.shape[0] == total_q:
111
+ out_flat = out0
112
+ elif out0.shape[0] == nheads and out0.shape[1] == total_q:
113
+ out_flat = out0.transpose(0, 1).contiguous()
114
+ else:
115
+ raise RuntimeError(f"Unexpected FA3 output shape {tuple(out0.shape)}")
116
+ out = out_flat.unflatten(0, (b, lq))
117
+
118
+ elif FLASH_ATTN_2_AVAILABLE:
119
+ out = flash_attn.flash_attn_varlen_func(
120
+ q=q_flat,
121
+ k=k_flat,
122
+ v=v_flat,
123
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q_flat.device, non_blocking=True),
124
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q_flat.device, non_blocking=True),
125
+ max_seqlen_q=lq,
126
+ max_seqlen_k=lk,
127
+ dropout_p=dropout_p,
128
+ softmax_scale=softmax_scale,
129
+ causal=causal,
130
+ window_size=ws, # <- pass 2-tuple
131
+ deterministic=deterministic,
132
+ ).unflatten(0, (b, lq))
133
+ else:
134
+ q_s = q.transpose(1, 2).to(dtype)
135
+ k_s = k.transpose(1, 2).to(dtype)
136
+ v_s = v.transpose(1, 2).to(dtype)
137
+ out = torch.nn.functional.scaled_dot_product_attention(
138
+ q_s, k_s, v_s, attn_mask=None, is_causal=causal, dropout_p=dropout_p
139
+ ).transpose(1, 2).contiguous()
140
+
141
+ return out.to(out_dtype)
142
+
143
+ @_wan_flash_attention_op.register_fake
144
+ def _wan_flash_attention_op_fake(
145
+ q,
146
+ k,
147
+ v,
148
+ q_lens=None,
149
+ k_lens=None,
150
+ dropout_p: float = 0.0,
151
+ softmax_scale=None,
152
+ q_scale=None,
153
+ causal: bool = False,
154
+ window_size: Optional[Sequence[int]] = None,
155
+ deterministic: bool = False,
156
+ dtype: torch.dtype = torch.bfloat16,
157
+ version: Optional[int] = None,
158
+ ):
159
+ # Match output shape: (B, Lq, Nq, Dh_v) and keep the SAME fake device as `q`
160
+ B, Lq, Nq, _ = q.shape
161
+ Dh_v = v.shape[-1]
162
+ return q.new_empty((B, Lq, Nq, Dh_v), dtype=q.dtype)
163
+
164
+
165
+
166
+ # ---------------------------
167
+ # Public API (unchanged signature)
168
+ # ---------------------------
169
+ def flash_attention(
170
+ q,
171
+ k,
172
+ v,
173
+ q_lens=None,
174
+ k_lens=None,
175
+ dropout_p=0.,
176
+ softmax_scale=None,
177
+ q_scale=None,
178
+ causal=False,
179
+ window_size=(-1, -1),
180
+ deterministic=False,
181
+ dtype=torch.bfloat16,
182
+ version=None,
183
+ ):
184
+ """
185
+ q: [B, Lq, Nq, C1].
186
+ k: [B, Lk, Nk, C1].
187
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
188
+ q_lens: [B].
189
+ k_lens: [B].
190
+ dropout_p: float. Dropout probability.
191
+ softmax_scale: float. The scaling of QK^T before applying softmax.
192
+ causal: bool. Whether to apply causal attention mask.
193
+ window_size: (left, right). If not (-1, -1), apply sliding window local attention.
194
+ deterministic: bool. If True, slightly slower and uses more memory.
195
+ dtype: torch.dtype. Dtype to cast q/k/v to when they are not float16/bfloat16.
196
+ """
197
+ # Simply delegate to the custom op so Dynamo/AOT treats it as a single node;
198
+ # our eager kernel inside _wan_flash_attention_op keeps the original behavior.
199
+ return _wan_flash_attention_op(
200
+ q, k, v,
201
+ q_lens=q_lens,
202
+ k_lens=k_lens,
203
+ dropout_p=dropout_p,
204
+ softmax_scale=softmax_scale,
205
+ q_scale=q_scale,
206
+ causal=causal,
207
+ window_size=window_size,
208
+ deterministic=deterministic,
209
+ dtype=dtype,
210
+ version=version,
211
+ )
212
+
213
+
214
+ def attention(
215
+ q,
216
+ k,
217
+ v,
218
+ q_lens=None,
219
+ k_lens=None,
220
+ dropout_p=0.,
221
+ softmax_scale=None,
222
+ q_scale=None,
223
+ causal=False,
224
+ window_size=(-1, -1),
225
+ deterministic=False,
226
+ dtype=torch.bfloat16,
227
+ fa_version=None,
228
+ ):
229
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
230
+ return flash_attention(
231
+ q=q,
232
+ k=k,
233
+ v=v,
234
+ q_lens=q_lens,
235
+ k_lens=k_lens,
236
+ dropout_p=dropout_p,
237
+ softmax_scale=softmax_scale,
238
+ q_scale=q_scale,
239
+ causal=causal,
240
+ window_size=window_size,
241
+ deterministic=deterministic,
242
+ dtype=dtype,
243
+ version=fa_version,
244
+ )
245
+ else:
246
+ if q_lens is not None or k_lens is not None:
247
+ warnings.warn(
248
+ 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
249
+ )
250
+ q_ = q.transpose(1, 2).to(dtype)
251
+ k_ = k.transpose(1, 2).to(dtype)
252
+ v_ = v.transpose(1, 2).to(dtype)
253
+ out = torch.nn.functional.scaled_dot_product_attention(
254
+ q_, k_, v_, attn_mask=None, is_causal=causal, dropout_p=dropout_p
255
+ )
256
+ return out.transpose(1, 2).contiguous()
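A shape-level sketch of how these wrappers are called, with hypothetical sizes; on a machine without flash-attn installed this exercises the scaled_dot_product_attention fallback (assumes the repo root is on PYTHONPATH and the package's dependencies are importable):

import torch

from humo.models.wan_modules.attention import attention

# q/k/v use the [batch, seq_len, num_heads, head_dim] layout expected above.
b, lq, lk, n, d = 1, 128, 64, 8, 64
q = torch.randn(b, lq, n, d)
k = torch.randn(b, lk, n, d)
v = torch.randn(b, lk, n, d)

out = attention(q, k, v, dtype=torch.bfloat16)
print(out.shape)  # torch.Size([1, 128, 8, 64])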
humo/models/wan_modules/clip.py ADDED
@@ -0,0 +1,542 @@
1
+ # Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms as T
10
+
11
+ from .attention import flash_attention
12
+ from .tokenizers import HuggingfaceTokenizer
13
+ from .xlm_roberta import XLMRoberta
14
+
15
+ __all__ = [
16
+ 'XLMRobertaCLIP',
17
+ 'clip_xlm_roberta_vit_h_14',
18
+ 'CLIPModel',
19
+ ]
20
+
21
+
22
+ def pos_interpolate(pos, seq_len):
23
+ if pos.size(1) == seq_len:
24
+ return pos
25
+ else:
26
+ src_grid = int(math.sqrt(pos.size(1)))
27
+ tar_grid = int(math.sqrt(seq_len))
28
+ n = pos.size(1) - src_grid * src_grid
29
+ return torch.cat([
30
+ pos[:, :n],
31
+ F.interpolate(
32
+ pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
33
+ 0, 3, 1, 2),
34
+ size=(tar_grid, tar_grid),
35
+ mode='bicubic',
36
+ align_corners=False).flatten(2).transpose(1, 2)
37
+ ],
38
+ dim=1)
39
+
40
+
41
+ class QuickGELU(nn.Module):
42
+
43
+ def forward(self, x):
44
+ return x * torch.sigmoid(1.702 * x)
45
+
46
+
47
+ class LayerNorm(nn.LayerNorm):
48
+
49
+ def forward(self, x):
50
+ return super().forward(x.float()).type_as(x)
51
+
52
+
53
+ class SelfAttention(nn.Module):
54
+
55
+ def __init__(self,
56
+ dim,
57
+ num_heads,
58
+ causal=False,
59
+ attn_dropout=0.0,
60
+ proj_dropout=0.0):
61
+ assert dim % num_heads == 0
62
+ super().__init__()
63
+ self.dim = dim
64
+ self.num_heads = num_heads
65
+ self.head_dim = dim // num_heads
66
+ self.causal = causal
67
+ self.attn_dropout = attn_dropout
68
+ self.proj_dropout = proj_dropout
69
+
70
+ # layers
71
+ self.to_qkv = nn.Linear(dim, dim * 3)
72
+ self.proj = nn.Linear(dim, dim)
73
+
74
+ def forward(self, x):
75
+ """
76
+ x: [B, L, C].
77
+ """
78
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
79
+
80
+ # compute query, key, value
81
+ q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
82
+
83
+ # compute attention
84
+ p = self.attn_dropout if self.training else 0.0
85
+ x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
86
+ x = x.reshape(b, s, c)
87
+
88
+ # output
89
+ x = self.proj(x)
90
+ x = F.dropout(x, self.proj_dropout, self.training)
91
+ return x
92
+
93
+
94
+ class SwiGLU(nn.Module):
95
+
96
+ def __init__(self, dim, mid_dim):
97
+ super().__init__()
98
+ self.dim = dim
99
+ self.mid_dim = mid_dim
100
+
101
+ # layers
102
+ self.fc1 = nn.Linear(dim, mid_dim)
103
+ self.fc2 = nn.Linear(dim, mid_dim)
104
+ self.fc3 = nn.Linear(mid_dim, dim)
105
+
106
+ def forward(self, x):
107
+ x = F.silu(self.fc1(x)) * self.fc2(x)
108
+ x = self.fc3(x)
109
+ return x
110
+
111
+
112
+ class AttentionBlock(nn.Module):
113
+
114
+ def __init__(self,
115
+ dim,
116
+ mlp_ratio,
117
+ num_heads,
118
+ post_norm=False,
119
+ causal=False,
120
+ activation='quick_gelu',
121
+ attn_dropout=0.0,
122
+ proj_dropout=0.0,
123
+ norm_eps=1e-5):
124
+ assert activation in ['quick_gelu', 'gelu', 'swi_glu']
125
+ super().__init__()
126
+ self.dim = dim
127
+ self.mlp_ratio = mlp_ratio
128
+ self.num_heads = num_heads
129
+ self.post_norm = post_norm
130
+ self.causal = causal
131
+ self.norm_eps = norm_eps
132
+
133
+ # layers
134
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
135
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
136
+ proj_dropout)
137
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
138
+ if activation == 'swi_glu':
139
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
140
+ else:
141
+ self.mlp = nn.Sequential(
142
+ nn.Linear(dim, int(dim * mlp_ratio)),
143
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
144
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
145
+
146
+ def forward(self, x):
147
+ if self.post_norm:
148
+ x = x + self.norm1(self.attn(x))
149
+ x = x + self.norm2(self.mlp(x))
150
+ else:
151
+ x = x + self.attn(self.norm1(x))
152
+ x = x + self.mlp(self.norm2(x))
153
+ return x
154
+
155
+
156
+ class AttentionPool(nn.Module):
157
+
158
+ def __init__(self,
159
+ dim,
160
+ mlp_ratio,
161
+ num_heads,
162
+ activation='gelu',
163
+ proj_dropout=0.0,
164
+ norm_eps=1e-5):
165
+ assert dim % num_heads == 0
166
+ super().__init__()
167
+ self.dim = dim
168
+ self.mlp_ratio = mlp_ratio
169
+ self.num_heads = num_heads
170
+ self.head_dim = dim // num_heads
171
+ self.proj_dropout = proj_dropout
172
+ self.norm_eps = norm_eps
173
+
174
+ # layers
175
+ gain = 1.0 / math.sqrt(dim)
176
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
177
+ self.to_q = nn.Linear(dim, dim)
178
+ self.to_kv = nn.Linear(dim, dim * 2)
179
+ self.proj = nn.Linear(dim, dim)
180
+ self.norm = LayerNorm(dim, eps=norm_eps)
181
+ self.mlp = nn.Sequential(
182
+ nn.Linear(dim, int(dim * mlp_ratio)),
183
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
184
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
185
+
186
+ def forward(self, x):
187
+ """
188
+ x: [B, L, C].
189
+ """
190
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
191
+
192
+ # compute query, key, value
193
+ q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
194
+ k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
195
+
196
+ # compute attention
197
+ x = flash_attention(q, k, v, version=2)
198
+ x = x.reshape(b, 1, c)
199
+
200
+ # output
201
+ x = self.proj(x)
202
+ x = F.dropout(x, self.proj_dropout, self.training)
203
+
204
+ # mlp
205
+ x = x + self.mlp(self.norm(x))
206
+ return x[:, 0]
207
+
208
+
209
+ class VisionTransformer(nn.Module):
210
+
211
+ def __init__(self,
212
+ image_size=224,
213
+ patch_size=16,
214
+ dim=768,
215
+ mlp_ratio=4,
216
+ out_dim=512,
217
+ num_heads=12,
218
+ num_layers=12,
219
+ pool_type='token',
220
+ pre_norm=True,
221
+ post_norm=False,
222
+ activation='quick_gelu',
223
+ attn_dropout=0.0,
224
+ proj_dropout=0.0,
225
+ embedding_dropout=0.0,
226
+ norm_eps=1e-5):
227
+ if image_size % patch_size != 0:
228
+ print(
229
+ '[WARNING] image_size is not divisible by patch_size',
230
+ flush=True)
231
+ assert pool_type in ('token', 'token_fc', 'attn_pool')
232
+ out_dim = out_dim or dim
233
+ super().__init__()
234
+ self.image_size = image_size
235
+ self.patch_size = patch_size
236
+ self.num_patches = (image_size // patch_size)**2
237
+ self.dim = dim
238
+ self.mlp_ratio = mlp_ratio
239
+ self.out_dim = out_dim
240
+ self.num_heads = num_heads
241
+ self.num_layers = num_layers
242
+ self.pool_type = pool_type
243
+ self.post_norm = post_norm
244
+ self.norm_eps = norm_eps
245
+
246
+ # embeddings
247
+ gain = 1.0 / math.sqrt(dim)
248
+ self.patch_embedding = nn.Conv2d(
249
+ 3,
250
+ dim,
251
+ kernel_size=patch_size,
252
+ stride=patch_size,
253
+ bias=not pre_norm)
254
+ if pool_type in ('token', 'token_fc'):
255
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
256
+ self.pos_embedding = nn.Parameter(gain * torch.randn(
257
+ 1, self.num_patches +
258
+ (1 if pool_type in ('token', 'token_fc') else 0), dim))
259
+ self.dropout = nn.Dropout(embedding_dropout)
260
+
261
+ # transformer
262
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
263
+ self.transformer = nn.Sequential(*[
264
+ AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
265
+ activation, attn_dropout, proj_dropout, norm_eps)
266
+ for _ in range(num_layers)
267
+ ])
268
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
269
+
270
+ # head
271
+ if pool_type == 'token':
272
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
273
+ elif pool_type == 'token_fc':
274
+ self.head = nn.Linear(dim, out_dim)
275
+ elif pool_type == 'attn_pool':
276
+ self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
277
+ proj_dropout, norm_eps)
278
+
279
+ def forward(self, x, interpolation=False, use_31_block=False):
280
+ b = x.size(0)
281
+
282
+ # embeddings
283
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
284
+ if self.pool_type in ('token', 'token_fc'):
285
+ x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
286
+ if interpolation:
287
+ e = pos_interpolate(self.pos_embedding, x.size(1))
288
+ else:
289
+ e = self.pos_embedding
290
+ x = self.dropout(x + e)
291
+ if self.pre_norm is not None:
292
+ x = self.pre_norm(x)
293
+
294
+ # transformer
295
+ if use_31_block:
296
+ x = self.transformer[:-1](x)
297
+ return x
298
+ else:
299
+ x = self.transformer(x)
300
+ return x
301
+
302
+
303
+ class XLMRobertaWithHead(XLMRoberta):
304
+
305
+ def __init__(self, **kwargs):
306
+ self.out_dim = kwargs.pop('out_dim')
307
+ super().__init__(**kwargs)
308
+
309
+ # head
310
+ mid_dim = (self.dim + self.out_dim) // 2
311
+ self.head = nn.Sequential(
312
+ nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
313
+ nn.Linear(mid_dim, self.out_dim, bias=False))
314
+
315
+ def forward(self, ids):
316
+ # xlm-roberta
317
+ x = super().forward(ids)
318
+
319
+ # average pooling
320
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
321
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
322
+
323
+ # head
324
+ x = self.head(x)
325
+ return x
326
+
327
+
328
+ class XLMRobertaCLIP(nn.Module):
329
+
330
+ def __init__(self,
331
+ embed_dim=1024,
332
+ image_size=224,
333
+ patch_size=14,
334
+ vision_dim=1280,
335
+ vision_mlp_ratio=4,
336
+ vision_heads=16,
337
+ vision_layers=32,
338
+ vision_pool='token',
339
+ vision_pre_norm=True,
340
+ vision_post_norm=False,
341
+ activation='gelu',
342
+ vocab_size=250002,
343
+ max_text_len=514,
344
+ type_size=1,
345
+ pad_id=1,
346
+ text_dim=1024,
347
+ text_heads=16,
348
+ text_layers=24,
349
+ text_post_norm=True,
350
+ text_dropout=0.1,
351
+ attn_dropout=0.0,
352
+ proj_dropout=0.0,
353
+ embedding_dropout=0.0,
354
+ norm_eps=1e-5):
355
+ super().__init__()
356
+ self.embed_dim = embed_dim
357
+ self.image_size = image_size
358
+ self.patch_size = patch_size
359
+ self.vision_dim = vision_dim
360
+ self.vision_mlp_ratio = vision_mlp_ratio
361
+ self.vision_heads = vision_heads
362
+ self.vision_layers = vision_layers
363
+ self.vision_pre_norm = vision_pre_norm
364
+ self.vision_post_norm = vision_post_norm
365
+ self.activation = activation
366
+ self.vocab_size = vocab_size
367
+ self.max_text_len = max_text_len
368
+ self.type_size = type_size
369
+ self.pad_id = pad_id
370
+ self.text_dim = text_dim
371
+ self.text_heads = text_heads
372
+ self.text_layers = text_layers
373
+ self.text_post_norm = text_post_norm
374
+ self.norm_eps = norm_eps
375
+
376
+ # models
377
+ self.visual = VisionTransformer(
378
+ image_size=image_size,
379
+ patch_size=patch_size,
380
+ dim=vision_dim,
381
+ mlp_ratio=vision_mlp_ratio,
382
+ out_dim=embed_dim,
383
+ num_heads=vision_heads,
384
+ num_layers=vision_layers,
385
+ pool_type=vision_pool,
386
+ pre_norm=vision_pre_norm,
387
+ post_norm=vision_post_norm,
388
+ activation=activation,
389
+ attn_dropout=attn_dropout,
390
+ proj_dropout=proj_dropout,
391
+ embedding_dropout=embedding_dropout,
392
+ norm_eps=norm_eps)
393
+ self.textual = XLMRobertaWithHead(
394
+ vocab_size=vocab_size,
395
+ max_seq_len=max_text_len,
396
+ type_size=type_size,
397
+ pad_id=pad_id,
398
+ dim=text_dim,
399
+ out_dim=embed_dim,
400
+ num_heads=text_heads,
401
+ num_layers=text_layers,
402
+ post_norm=text_post_norm,
403
+ dropout=text_dropout)
404
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
405
+
406
+ def forward(self, imgs, txt_ids):
407
+ """
408
+ imgs: [B, 3, H, W] of torch.float32.
409
+ - mean: [0.48145466, 0.4578275, 0.40821073]
410
+ - std: [0.26862954, 0.26130258, 0.27577711]
411
+ txt_ids: [B, L] of torch.long.
412
+ Encoded by data.CLIPTokenizer.
413
+ """
414
+ xi = self.visual(imgs)
415
+ xt = self.textual(txt_ids)
416
+ return xi, xt
417
+
418
+ def param_groups(self):
419
+ groups = [{
420
+ 'params': [
421
+ p for n, p in self.named_parameters()
422
+ if 'norm' in n or n.endswith('bias')
423
+ ],
424
+ 'weight_decay': 0.0
425
+ }, {
426
+ 'params': [
427
+ p for n, p in self.named_parameters()
428
+ if not ('norm' in n or n.endswith('bias'))
429
+ ]
430
+ }]
431
+ return groups
432
+
433
+
434
+ def _clip(pretrained=False,
435
+ pretrained_name=None,
436
+ model_cls=XLMRobertaCLIP,
437
+ return_transforms=False,
438
+ return_tokenizer=False,
439
+ tokenizer_padding='eos',
440
+ dtype=torch.float32,
441
+ device='cpu',
442
+ **kwargs):
443
+ # init a model on device
444
+ with torch.device(device):
445
+ model = model_cls(**kwargs)
446
+
447
+ # set device
448
+ model = model.to(dtype=dtype, device=device)
449
+ output = (model,)
450
+
451
+ # init transforms
452
+ if return_transforms:
453
+ # mean and std
454
+ if 'siglip' in pretrained_name.lower():
455
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
456
+ else:
457
+ mean = [0.48145466, 0.4578275, 0.40821073]
458
+ std = [0.26862954, 0.26130258, 0.27577711]
459
+
460
+ # transforms
461
+ transforms = T.Compose([
462
+ T.Resize((model.image_size, model.image_size),
463
+ interpolation=T.InterpolationMode.BICUBIC),
464
+ T.ToTensor(),
465
+ T.Normalize(mean=mean, std=std)
466
+ ])
467
+ output += (transforms,)
468
+ return output[0] if len(output) == 1 else output
469
+
470
+
471
+ def clip_xlm_roberta_vit_h_14(
472
+ pretrained=False,
473
+ pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
474
+ **kwargs):
475
+ cfg = dict(
476
+ embed_dim=1024,
477
+ image_size=224,
478
+ patch_size=14,
479
+ vision_dim=1280,
480
+ vision_mlp_ratio=4,
481
+ vision_heads=16,
482
+ vision_layers=32,
483
+ vision_pool='token',
484
+ activation='gelu',
485
+ vocab_size=250002,
486
+ max_text_len=514,
487
+ type_size=1,
488
+ pad_id=1,
489
+ text_dim=1024,
490
+ text_heads=16,
491
+ text_layers=24,
492
+ text_post_norm=True,
493
+ text_dropout=0.1,
494
+ attn_dropout=0.0,
495
+ proj_dropout=0.0,
496
+ embedding_dropout=0.0)
497
+ cfg.update(**kwargs)
498
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
499
+
500
+
501
+ class CLIPModel:
502
+
503
+ def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
504
+ self.dtype = dtype
505
+ self.device = device
506
+ self.checkpoint_path = checkpoint_path
507
+ self.tokenizer_path = tokenizer_path
508
+
509
+ # init model
510
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
511
+ pretrained=False,
512
+ return_transforms=True,
513
+ return_tokenizer=False,
514
+ dtype=dtype,
515
+ device=device)
516
+ self.model = self.model.eval().requires_grad_(False)
517
+ logging.info(f'loading {checkpoint_path}')
518
+ self.model.load_state_dict(
519
+ torch.load(checkpoint_path, map_location='cpu'))
520
+
521
+ # init tokenizer
522
+ self.tokenizer = HuggingfaceTokenizer(
523
+ name=tokenizer_path,
524
+ seq_len=self.model.max_text_len - 2,
525
+ clean='whitespace')
526
+
527
+ def visual(self, videos):
528
+ # preprocess
529
+ size = (self.model.image_size,) * 2
530
+ videos = torch.cat([
531
+ F.interpolate(
532
+ u.transpose(0, 1),
533
+ size=size,
534
+ mode='bicubic',
535
+ align_corners=False) for u in videos
536
+ ])
537
+ videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
538
+
539
+ # forward
540
+ with torch.amp.autocast('cuda', dtype=self.dtype):
541
+ out = self.model.visual(videos, use_31_block=True)
542
+ return out
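A small sketch of what pos_interpolate above does when the vision tower sees a larger patch grid than it was trained on (toy sizes; assumes the module is importable as in the sketches above):

import torch

from humo.models.wan_modules.clip import pos_interpolate

# 1 cls slot + a 16x16 grid of positional embeddings, toy feature dim 8.
pos = torch.randn(1, 1 + 16 * 16, 8)

# Ask for a 32x32 grid (e.g. a 448-pixel input with patch_size=14): the cls slot
# is kept and the grid part is bicubically resized.
new_pos = pos_interpolate(pos, seq_len=1 + 32 * 32)
print(new_pos.shape)  # torch.Size([1, 1025, 8])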
humo/models/wan_modules/model.py ADDED
@@ -0,0 +1,619 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import math
3
+
4
+ import torch
5
+ import torch.cuda.amp as amp
6
+ import torch.nn as nn
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.models.modeling_utils import ModelMixin
9
+
10
+ from .attention import flash_attention
11
+
12
+ __all__ = ['WanModel']
13
+
14
+
15
+ def sinusoidal_embedding_1d(dim, position):
16
+ # preprocess
17
+ assert dim % 2 == 0
18
+ half = dim // 2
19
+ position = position.type(torch.float64)
20
+
21
+ # calculation
22
+ sinusoid = torch.outer(
23
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
24
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
25
+ return x
26
+
27
+
28
+ @torch.amp.autocast("cuda", enabled=False)
29
+ def rope_params(max_seq_len, dim, theta=10000):
30
+ assert dim % 2 == 0
31
+ freqs = torch.outer(
32
+ torch.arange(max_seq_len),
33
+ 1.0 / torch.pow(theta,
34
+ torch.arange(0, dim, 2).to(torch.float64).div(dim)))
35
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
36
+ return freqs
37
+
38
+
39
+ @torch.amp.autocast("cuda", enabled=False)
40
+ def rope_apply(x, grid_sizes, freqs):
41
+ n, c = x.size(2), x.size(3) // 2
42
+
43
+ # split freqs
44
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
45
+
46
+ # loop over samples
47
+ output = []
48
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
49
+ seq_len = f * h * w
50
+
51
+ # precompute multipliers
52
+ x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
53
+ seq_len, n, -1, 2))
54
+ freqs_i = torch.cat([
55
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
56
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
57
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
58
+ ],
59
+ dim=-1).reshape(seq_len, 1, -1)
60
+
61
+ # apply rotary embedding
62
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
63
+ x_i = torch.cat([x_i, x[i, seq_len:]])
64
+
65
+ # append to collection
66
+ output.append(x_i)
67
+ return torch.stack(output).float()
68
+
69
+
70
+ class WanRMSNorm(nn.Module):
71
+
72
+ def __init__(self, dim, eps=1e-5):
73
+ super().__init__()
74
+ self.dim = dim
75
+ self.eps = eps
76
+ self.weight = nn.Parameter(torch.ones(dim))
77
+
78
+ def forward(self, x):
79
+ r"""
80
+ Args:
81
+ x(Tensor): Shape [B, L, C]
82
+ """
83
+ return self._norm(x.float()).type_as(x) * self.weight
84
+
85
+ def _norm(self, x):
86
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
87
+
88
+
89
+ class WanLayerNorm(nn.LayerNorm):
90
+
91
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
92
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
93
+
94
+ def forward(self, x):
95
+ r"""
96
+ Args:
97
+ x(Tensor): Shape [B, L, C]
98
+ """
99
+ return super().forward(x.float()).type_as(x)
100
+
101
+
102
+ class WanSelfAttention(nn.Module):
103
+
104
+ def __init__(self,
105
+ dim,
106
+ num_heads,
107
+ window_size=(-1, -1),
108
+ qk_norm=True,
109
+ eps=1e-6):
110
+ assert dim % num_heads == 0
111
+ super().__init__()
112
+ self.dim = dim
113
+ self.num_heads = num_heads
114
+ self.head_dim = dim // num_heads
115
+ self.window_size = window_size
116
+ self.qk_norm = qk_norm
117
+ self.eps = eps
118
+
119
+ # layers
120
+ self.q = nn.Linear(dim, dim)
121
+ self.k = nn.Linear(dim, dim)
122
+ self.v = nn.Linear(dim, dim)
123
+ self.o = nn.Linear(dim, dim)
124
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
125
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
126
+
127
+ def forward(self, x, seq_lens, grid_sizes, freqs):
128
+ r"""
129
+ Args:
130
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
131
+ seq_lens(Tensor): Shape [B]
132
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
133
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
134
+ """
135
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
136
+
137
+ # query, key, value function
138
+ def qkv_fn(x):
139
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
140
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
141
+ v = self.v(x).view(b, s, n, d)
142
+ return q, k, v
143
+
144
+ q, k, v = qkv_fn(x)
145
+
146
+ x = flash_attention(
147
+ q=rope_apply(q, grid_sizes, freqs),
148
+ k=rope_apply(k, grid_sizes, freqs),
149
+ v=v,
150
+ k_lens=seq_lens,
151
+ window_size=self.window_size)
152
+
153
+ # output
154
+ x = x.flatten(2)
155
+ x = self.o(x)
156
+ return x
157
+
158
+
159
+ class WanT2VCrossAttention(WanSelfAttention):
160
+
161
+ def forward(self, x, context, context_lens):
162
+ r"""
163
+ Args:
164
+ x(Tensor): Shape [B, L1, C]
165
+ context(Tensor): Shape [B, L2, C]
166
+ context_lens(Tensor): Shape [B]
167
+ """
168
+ b, n, d = x.size(0), self.num_heads, self.head_dim
169
+
170
+ # compute query, key, value
171
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
172
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
173
+ v = self.v(context).view(b, -1, n, d)
174
+
175
+ # compute attention
176
+ x = flash_attention(q, k, v, k_lens=context_lens)
177
+
178
+ # output
179
+ x = x.flatten(2)
180
+ x = self.o(x)
181
+ return x
182
+
183
+
184
+ class WanI2VCrossAttention(WanSelfAttention):
185
+
186
+ def __init__(self,
187
+ dim,
188
+ num_heads,
189
+ window_size=(-1, -1),
190
+ qk_norm=True,
191
+ eps=1e-6):
192
+ super().__init__(dim, num_heads, window_size, qk_norm, eps)
193
+
194
+ self.k_img = nn.Linear(dim, dim)
195
+ self.v_img = nn.Linear(dim, dim)
196
+ # self.alpha = nn.Parameter(torch.zeros((1, )))
197
+ self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
198
+
199
+ def forward(self, x, context, context_lens):
200
+ r"""
201
+ Args:
202
+ x(Tensor): Shape [B, L1, C]
203
+ context(Tensor): Shape [B, L2, C]
204
+ context_lens(Tensor): Shape [B]
205
+ """
206
+ context_img = context[:, :257]
207
+ context = context[:, 257:]
208
+ b, n, d = x.size(0), self.num_heads, self.head_dim
209
+
210
+ # compute query, key, value
211
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
212
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
213
+ v = self.v(context).view(b, -1, n, d)
214
+ k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
215
+ v_img = self.v_img(context_img).view(b, -1, n, d)
216
+ img_x = flash_attention(q, k_img, v_img, k_lens=None)
217
+ # compute attention
218
+ x = flash_attention(q, k, v, k_lens=context_lens)
219
+
220
+ # output
221
+ x = x.flatten(2)
222
+ img_x = img_x.flatten(2)
223
+ x = x + img_x
224
+ x = self.o(x)
225
+ return x
226
+
227
+
228
+ WAN_CROSSATTENTION_CLASSES = {
229
+ 't2v_cross_attn': WanT2VCrossAttention,
230
+ 'i2v_cross_attn': WanI2VCrossAttention,
231
+ }
232
+
233
+
234
+ class WanAttentionBlock(nn.Module):
235
+
236
+ def __init__(self,
237
+ cross_attn_type,
238
+ dim,
239
+ ffn_dim,
240
+ num_heads,
241
+ window_size=(-1, -1),
242
+ qk_norm=True,
243
+ cross_attn_norm=False,
244
+ eps=1e-6):
245
+ super().__init__()
246
+ self.dim = dim
247
+ self.ffn_dim = ffn_dim
248
+ self.num_heads = num_heads
249
+ self.window_size = window_size
250
+ self.qk_norm = qk_norm
251
+ self.cross_attn_norm = cross_attn_norm
252
+ self.eps = eps
253
+
254
+ # layers
255
+ self.norm1 = WanLayerNorm(dim, eps)
256
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
257
+ eps)
258
+ self.norm3 = WanLayerNorm(
259
+ dim, eps,
260
+ elementwise_affine=True) if cross_attn_norm else nn.Identity()
261
+ self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
262
+ num_heads,
263
+ (-1, -1),
264
+ qk_norm,
265
+ eps)
266
+ self.norm2 = WanLayerNorm(dim, eps)
267
+ self.ffn = nn.Sequential(
268
+ nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
269
+ nn.Linear(ffn_dim, dim))
270
+
271
+ # modulation
272
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
273
+
274
+ def forward(
275
+ self,
276
+ x,
277
+ e,
278
+ seq_lens,
279
+ grid_sizes,
280
+ freqs,
281
+ context,
282
+ context_lens,
283
+ ):
284
+ r"""
285
+ Args:
286
+ x(Tensor): Shape [B, L, C]
287
+ e(Tensor): Shape [B, 6, C]
288
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
289
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
290
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
291
+ """
292
+ assert e.dtype == torch.float32
293
+ with torch.amp.autocast('cuda', dtype=torch.float32):
294
+ e = (self.modulation + e).chunk(6, dim=1)
295
+ assert e[0].dtype == torch.float32
296
+
297
+ # self-attention
298
+ y = self.self_attn(
299
+ self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
300
+ freqs)
301
+ with torch.amp.autocast('cuda', dtype=torch.float32):
302
+ x = x + y * e[2]
303
+
304
+ # cross-attention & ffn function
305
+ def cross_attn_ffn(x, context, context_lens, e):
306
+ x = x + self.cross_attn(self.norm3(x), context, context_lens)
307
+ y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
308
+ with torch.amp.autocast('cuda', dtype=torch.float32):
309
+ x = x + y * e[5]
310
+ return x
311
+
312
+ x = cross_attn_ffn(x, context, context_lens, e)
313
+ return x
314
+
315
+
316
+ class Head(nn.Module):
317
+
318
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
319
+ super().__init__()
320
+ self.dim = dim
321
+ self.out_dim = out_dim
322
+ self.patch_size = patch_size
323
+ self.eps = eps
324
+
325
+ # layers
326
+ out_dim = math.prod(patch_size) * out_dim
327
+ self.norm = WanLayerNorm(dim, eps)
328
+ self.head = nn.Linear(dim, out_dim)
329
+
330
+ # modulation
331
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
332
+
333
+ def forward(self, x, e):
334
+ r"""
335
+ Args:
336
+ x(Tensor): Shape [B, L1, C]
337
+ e(Tensor): Shape [B, C]
338
+ """
339
+ assert e.dtype == torch.float32
340
+ with torch.amp.autocast('cuda', dtype=torch.float32):
341
+ e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
342
+ x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
343
+ return x
344
+
345
+
346
+ class MLPProj(torch.nn.Module):
347
+
348
+ def __init__(self, in_dim, out_dim):
349
+ super().__init__()
350
+
351
+ self.proj = torch.nn.Sequential(
352
+ torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
353
+ torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
354
+ torch.nn.LayerNorm(out_dim))
355
+
356
+ def forward(self, image_embeds):
357
+ clip_extra_context_tokens = self.proj(image_embeds)
358
+ return clip_extra_context_tokens
359
+
360
+
361
+ class WanModel(ModelMixin, ConfigMixin):
362
+ r"""
363
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
364
+ """
365
+
366
+ ignore_for_config = [
367
+ 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
368
+ ]
369
+ _no_split_modules = ['WanAttentionBlock']
370
+
371
+ @register_to_config
372
+ def __init__(self,
373
+ model_type='t2v',
374
+ patch_size=(1, 2, 2),
375
+ text_len=512,
376
+ in_dim=16,
377
+ dim=5120,
378
+ ffn_dim=13824,
379
+ freq_dim=256,
380
+ text_dim=4096,
381
+ out_dim=16,
382
+ num_heads=40,
383
+ num_layers=40,
384
+ window_size=(-1, -1),
385
+ qk_norm=True,
386
+ cross_attn_norm=True,
387
+ eps=1e-6):
388
+ r"""
389
+ Initialize the diffusion model backbone.
390
+
391
+ Args:
392
+ model_type (`str`, *optional*, defaults to 't2v'):
393
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
394
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
395
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
396
+ text_len (`int`, *optional*, defaults to 512):
397
+ Fixed length for text embeddings
398
+ in_dim (`int`, *optional*, defaults to 16):
399
+ Input video channels (C_in)
400
+ dim (`int`, *optional*, defaults to 5120):
401
+ Hidden dimension of the transformer
402
+ ffn_dim (`int`, *optional*, defaults to 13824):
403
+ Intermediate dimension in feed-forward network
404
+ freq_dim (`int`, *optional*, defaults to 256):
405
+ Dimension for sinusoidal time embeddings
406
+ text_dim (`int`, *optional*, defaults to 4096):
407
+ Input dimension for text embeddings
408
+ out_dim (`int`, *optional*, defaults to 16):
409
+ Output video channels (C_out)
410
+ num_heads (`int`, *optional*, defaults to 40):
411
+ Number of attention heads
412
+ num_layers (`int`, *optional*, defaults to 40):
413
+ Number of transformer blocks
414
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
415
+ Window size for local attention (-1 indicates global attention)
416
+ qk_norm (`bool`, *optional*, defaults to True):
417
+ Enable query/key normalization
418
+ cross_attn_norm (`bool`, *optional*, defaults to True):
419
+ Enable cross-attention normalization
420
+ eps (`float`, *optional*, defaults to 1e-6):
421
+ Epsilon value for normalization layers
422
+ """
423
+
424
+ super().__init__()
425
+
426
+ assert model_type in ['t2v', 'i2v']
427
+ self.model_type = model_type
428
+
429
+ self.patch_size = patch_size
430
+ self.text_len = text_len
431
+ self.in_dim = in_dim
432
+ self.dim = dim
433
+ self.ffn_dim = ffn_dim
434
+ self.freq_dim = freq_dim
435
+ self.text_dim = text_dim
436
+ self.out_dim = out_dim
437
+ self.num_heads = num_heads
438
+ self.num_layers = num_layers
439
+ self.window_size = window_size
440
+ self.qk_norm = qk_norm
441
+ self.cross_attn_norm = cross_attn_norm
442
+ self.eps = eps
443
+
444
+ # embeddings
445
+ self.patch_embedding = nn.Conv3d(
446
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
447
+ self.text_embedding = nn.Sequential(
448
+ nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
449
+ nn.Linear(dim, dim))
450
+
451
+ self.time_embedding = nn.Sequential(
452
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
453
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
454
+
455
+ # blocks
456
+ cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
457
+ self.blocks = nn.ModuleList([
458
+ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
459
+ window_size, qk_norm, cross_attn_norm, eps)
460
+ for _ in range(num_layers)
461
+ ])
462
+
463
+ # head
464
+ self.head = Head(dim, out_dim, patch_size, eps)
465
+
466
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
467
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
468
+ d = dim // num_heads
469
+ self.freqs = torch.cat([
470
+ rope_params(1024, d - 4 * (d // 6)),
471
+ rope_params(1024, 2 * (d // 6)),
472
+ rope_params(1024, 2 * (d // 6))
473
+ ],
474
+ dim=1)
475
+
476
+ if model_type == 'i2v':
477
+ self.img_emb = MLPProj(1280, dim)
478
+
479
+ # initialize weights
480
+ self.init_weights()
481
+
482
+ def forward(
483
+ self,
484
+ x,
485
+ t,
486
+ context,
487
+ seq_len,
488
+ clip_fea=None,
489
+ y=None,
490
+ ):
491
+ r"""
492
+ Forward pass through the diffusion model
493
+
494
+ Args:
495
+ x (List[Tensor]):
496
+ List of input video tensors, each with shape [C_in, F, H, W]
497
+ t (Tensor):
498
+ Diffusion timesteps tensor of shape [B]
499
+ context (List[Tensor]):
500
+ List of text embeddings each with shape [L, C]
501
+ seq_len (`int`):
502
+ Maximum sequence length for positional encoding
503
+ clip_fea (Tensor, *optional*):
504
+ CLIP image features for image-to-video mode
505
+ y (List[Tensor], *optional*):
506
+ Conditional video inputs for image-to-video mode, same shape as x
507
+
508
+ Returns:
509
+ List[Tensor]:
510
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
511
+ """
512
+ if self.model_type == 'i2v':
513
+ assert clip_fea is not None and y is not None
514
+ # params
515
+ device = self.patch_embedding.weight.device
516
+ freqs = self.freqs.to(device)
517
+
518
+ if y is not None:
519
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
520
+
521
+ # embeddings
522
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
523
+ grid_sizes = torch.stack(
524
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
525
+ x = [u.flatten(2).transpose(1, 2) for u in x]
526
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
527
+ assert seq_lens.max() <= seq_len
528
+ x = torch.cat([
529
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
530
+ dim=1) for u in x
531
+ ])
532
+
533
+ # time embeddings
534
+ with torch.amp.autocast('cuda', dtype=torch.float32):
535
+ e = self.time_embedding(
536
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
537
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
538
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
539
+
540
+ # context
541
+ context_lens = None
542
+ context = self.text_embedding(
543
+ torch.stack([
544
+ torch.cat(
545
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
546
+ for u in context
547
+ ]))
548
+
549
+ if clip_fea is not None:
550
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
551
+ context = torch.concat([context_clip, context], dim=1)
552
+
553
+ # arguments
554
+ kwargs = dict(
555
+ e=e0,
556
+ seq_lens=seq_lens,
557
+ grid_sizes=grid_sizes,
558
+ freqs=freqs,
559
+ context=context,
560
+ context_lens=context_lens)
561
+
562
+ for block in self.blocks:
563
+ x = block(x, **kwargs)
564
+
565
+ # head
566
+ x = self.head(x, e)
567
+
568
+ # unpatchify
569
+ x = self.unpatchify(x, grid_sizes)
570
+ return [u.float() for u in x]
571
+
572
+ def unpatchify(self, x, grid_sizes):
573
+ r"""
574
+ Reconstruct video tensors from patch embeddings.
575
+
576
+ Args:
577
+ x (List[Tensor]):
578
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
579
+ grid_sizes (Tensor):
580
+ Original spatial-temporal grid dimensions before patching,
581
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
582
+
583
+ Returns:
584
+ List[Tensor]:
585
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
586
+ """
587
+
588
+ c = self.out_dim
589
+ out = []
590
+ for u, v in zip(x, grid_sizes.tolist()):
591
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
592
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
593
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
594
+ out.append(u)
595
+ return out
596
+
597
+ def init_weights(self):
598
+ r"""
599
+ Initialize model parameters using Xavier initialization.
600
+ """
601
+
602
+ # basic init
603
+ for m in self.modules():
604
+ if isinstance(m, nn.Linear):
605
+ nn.init.xavier_uniform_(m.weight)
606
+ if m.bias is not None:
607
+ nn.init.zeros_(m.bias)
608
+
609
+ # init embeddings
610
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
611
+ for m in self.text_embedding.modules():
612
+ if isinstance(m, nn.Linear):
613
+ nn.init.normal_(m.weight, std=.02)
614
+ for m in self.time_embedding.modules():
615
+ if isinstance(m, nn.Linear):
616
+ nn.init.normal_(m.weight, std=.02)
617
+
618
+ # init output layer
619
+ nn.init.zeros_(self.head.head.weight)
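A shape-only sketch of the 3D rotary embedding used by WanSelfAttention above: the per-head channels are split into temporal, height and width bands, and each band is rotated according to the latent grid. Toy sizes; assumes the module and its dependencies (torch, diffusers) are importable:

import torch

from humo.models.wan_modules.model import rope_params, rope_apply

dim, num_heads = 128, 2            # toy model; head_dim d = 64
d = dim // num_heads
freqs = torch.cat([
    rope_params(1024, d - 4 * (d // 6)),  # temporal band
    rope_params(1024, 2 * (d // 6)),      # height band
    rope_params(1024, 2 * (d // 6)),      # width band
], dim=1)

f, h, w = 2, 4, 4                  # latent grid (F, H, W) after patch embedding
x = torch.randn(1, f * h * w, num_heads, d)
grid_sizes = torch.tensor([[f, h, w]])

out = rope_apply(x, grid_sizes, freqs)
print(out.shape)                   # torch.Size([1, 32, 2, 64])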
humo/models/wan_modules/model_humo.py ADDED
@@ -0,0 +1,803 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ from common.distributed import get_device
5
+ from models.audio.audio_proj import AudioProjModel
6
+
7
+ import torch.cuda.amp as amp
8
+ import math
9
+ from humo.models.wan_modules.attention import flash_attention
10
+ from common.distributed.advanced import is_unified_parallel_initialized
11
+
12
+ import types
13
+
14
+ def sinusoidal_embedding_1d(dim, position):
15
+ # preprocess
16
+ assert dim % 2 == 0
17
+ half = dim // 2
18
+ position = position.type(torch.float64)
19
+
20
+ # calculation
21
+ sinusoid = torch.outer(
22
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
23
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
24
+ return x
25
+
26
+
27
+ @amp.autocast(enabled=False)
28
+ def rope_params(max_seq_len, dim, theta=10000):
29
+ assert dim % 2 == 0
30
+ freqs = torch.outer(
31
+ torch.arange(max_seq_len),
32
+ 1.0 / torch.pow(theta,
33
+ torch.arange(0, dim, 2).to(torch.float32).div(dim)))
34
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
35
+ return freqs
36
+
37
+
38
+ @amp.autocast(enabled=False)
39
+ def rope_apply(x, grid_sizes, freqs):
40
+ n, c = x.size(2), x.size(3) // 2
41
+
42
+ # split freqs
43
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
44
+
45
+ # loop over samples
46
+ output = []
47
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
48
+ seq_len = f * h * w
49
+
50
+ # precompute multipliers
51
+ x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape(
52
+ seq_len, n, -1, 2))
53
+ freqs_i = torch.cat([
54
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
55
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
56
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
57
+ ],
58
+ dim=-1).reshape(seq_len, 1, -1)
59
+
60
+ # apply rotary embedding
61
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
62
+ x_i = torch.cat([x_i, x[i, seq_len:]])
63
+
64
+ # append to collection
65
+ output.append(x_i)
66
+ return torch.stack(output).float()
67
+
68
+
69
+ class WanRMSNorm(nn.Module):
70
+
71
+ def __init__(self, dim, eps=1e-5):
72
+ super().__init__()
73
+ self.dim = dim
74
+ self.eps = eps
75
+ self.weight = nn.Parameter(torch.ones(dim))
76
+
77
+ def forward(self, x):
78
+ r"""
79
+ Args:
80
+ x(Tensor): Shape [B, L, C]
81
+ """
82
+ return self._norm(x.float()).type_as(x) * self.weight
83
+
84
+ def _norm(self, x):
85
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
86
+
87
+
88
+ class WanLayerNorm(nn.LayerNorm):
89
+
90
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
91
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
92
+
93
+ def forward(self, x):
94
+ r"""
95
+ Args:
96
+ x(Tensor): Shape [B, L, C]
97
+ """
98
+ return super().forward(x.float()).type_as(x)
99
+
100
+
101
+ class WanSelfAttention(nn.Module):
102
+
103
+ def __init__(self,
104
+ dim,
105
+ num_heads,
106
+ window_size=(-1, -1),
107
+ qk_norm=True,
108
+ eps=1e-6):
109
+ assert dim % num_heads == 0
110
+ super().__init__()
111
+ self.dim = dim
112
+ self.num_heads = num_heads
113
+ self.head_dim = dim // num_heads
114
+ self.window_size = window_size
115
+ self.qk_norm = qk_norm
116
+ self.eps = eps
117
+
118
+ # layers
119
+ self.q = nn.Linear(dim, dim)
120
+ self.k = nn.Linear(dim, dim)
121
+ self.v = nn.Linear(dim, dim)
122
+ self.o = nn.Linear(dim, dim)
123
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
124
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
125
+
126
+ def forward(self, x, seq_lens, grid_sizes, freqs):
127
+ r"""
128
+ Args:
129
+ x(Tensor): Shape [B, L, C], e.g. torch.Size([1, 9360, 5120])
130
+ seq_lens(Tensor): Shape [B], tensor([9360])
131
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W), tensor([[ 6, 30, 52]])
132
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
133
+ """
134
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
135
+
136
+ # query, key, value function
137
+ def qkv_fn(x):
138
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
139
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
140
+ v = self.v(x).view(b, s, n, d)
141
+ return q, k, v
142
+
143
+ q, k, v = qkv_fn(x)
144
+
145
+ x = flash_attention(
146
+ q=rope_apply(q, grid_sizes, freqs),
147
+ k=rope_apply(k, grid_sizes, freqs),
148
+ v=v,
149
+ k_lens=seq_lens,
150
+ window_size=self.window_size)
151
+
152
+ # output
153
+ x = x.flatten(2)
154
+ x = self.o(x)
155
+ return x
156
+
157
+
158
+ class WanSelfAttentionSepKVDim(nn.Module):
159
+
160
+ def __init__(self,
161
+ kv_dim,
162
+ dim,
163
+ num_heads,
164
+ window_size=(-1, -1),
165
+ qk_norm=True,
166
+ eps=1e-6):
167
+ assert dim % num_heads == 0
168
+ super().__init__()
169
+ self.dim = dim
170
+ self.num_heads = num_heads
171
+ self.head_dim = dim // num_heads
172
+ self.window_size = window_size
173
+ self.qk_norm = qk_norm
174
+ self.eps = eps
175
+
176
+ # layers
177
+ self.q = nn.Linear(dim, dim)
178
+ self.k = nn.Linear(kv_dim, dim)
179
+ self.v = nn.Linear(kv_dim, dim)
180
+ self.o = nn.Linear(dim, dim)
181
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
182
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
183
+
184
+ def forward(self, x, seq_lens, grid_sizes, freqs):
185
+ r"""
186
+ Args:
187
+ x(Tensor): Shape [B, L, C], e.g. torch.Size([1, 9360, 5120])
188
+ seq_lens(Tensor): Shape [B], tensor([9360])
189
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W), tensor([[ 6, 30, 52]])
190
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
191
+ """
192
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
193
+
194
+ # query, key, value function
195
+ def qkv_fn(x):
196
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
197
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
198
+ v = self.v(x).view(b, s, n, d)
199
+ return q, k, v
200
+
201
+ q, k, v = qkv_fn(x)
202
+
203
+ x = flash_attention(
204
+ q=rope_apply(q, grid_sizes, freqs),
205
+ k=rope_apply(k, grid_sizes, freqs),
206
+ v=v,
207
+ k_lens=seq_lens,
208
+ window_size=self.window_size)
209
+
210
+ # output
211
+ x = x.flatten(2)
212
+ x = self.o(x)
213
+ return x
214
+
215
+
216
+
217
+ class WanT2VCrossAttention(WanSelfAttention):
218
+
219
+ def forward(self, x, context, context_lens):
220
+ r"""
221
+ Args:
222
+ x(Tensor): Shape [B, L1, C]
223
+ context(Tensor): Shape [B, L2, C]
224
+ context_lens(Tensor): Shape [B]
225
+ """
226
+ b, n, d = x.size(0), self.num_heads, self.head_dim
227
+
228
+ # compute query, key, value
229
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
230
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
231
+ v = self.v(context).view(b, -1, n, d)
232
+
233
+ # compute attention
234
+ x = flash_attention(q, k, v, k_lens=context_lens)
235
+
236
+ # output
237
+ x = x.flatten(2)
238
+ x = self.o(x)
239
+ return x
240
+
241
+
242
+ class WanT2VCrossAttentionGather(WanSelfAttentionSepKVDim):
243
+
244
+ def forward(self, x, context, context_lens, grid_sizes, freqs, audio_seq_len):
245
+ b, n, d = x.size(0), self.num_heads, self.head_dim
246
+
247
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
248
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
249
+ v = self.v(context).view(b, -1, n, d)
250
+
251
+ # --- NEW: derive sizes from shapes (SymInts), no int(tensor) casts ---
252
+ Lq = q.shape[1] # total video tokens per sample
253
+ # audio has 16 tokens per frame -> frames = audio_tokens // 16
254
+ frames = (context.shape[1] // 16)
255
+ hlen_wlen = Lq // frames # tokens per frame = H*W
256
+
257
+ # Now reshape using SymInt-derived sizes
258
+ q = q.reshape(-1, hlen_wlen, n, d)
259
+ k = k.reshape(-1, 16, n, d)
260
+ v = v.reshape(-1, 16, n, d)
261
+
262
+ x = flash_attention(q, k, v, k_lens=None)
263
+ x = x.view(b, -1, n, d).flatten(2)
264
+ x = self.o(x)
265
+ return x
266
+
267
+ # def forward(self, x, context, context_lens, grid_sizes, freqs, audio_seq_len):
268
+ # r"""
269
+ # Args:
270
+ # x(Tensor): Shape [B, L1, C] - video tokens
271
+ # context(Tensor): Shape [B, L2, C] - audio tokens with shape [B, frames*16, 1536]
272
+ # context_lens(Tensor): Shape [B] - actually seq_lens from call (video sequence length)
273
+ # grid_sizes(Tensor): Shape [B, 3] - video grid dimensions (F, H, W)
274
+ # freqs(Tensor): RoPE frequencies
275
+ # audio_seq_len(Tensor): Actual audio sequence length (frames * 16)
276
+ # """
277
+ # b, n, d = x.size(0), self.num_heads, self.head_dim
278
+
279
+ # q = self.norm_q(self.q(x)).view(b, -1, n, d)
280
+ # k = self.norm_k(self.k(context)).view(b, -1, n, d)
281
+ # v = self.v(context).view(b, -1, n, d)
282
+
283
+ # # Handle video spatial structure
284
+ # hlen_wlen = int(grid_sizes[0][1] * grid_sizes[0][2])
285
+ # q = q.reshape(-1, hlen_wlen, n, d)
286
+
287
+ # # Handle audio temporal structure (16 tokens per frame)
288
+ # k = k.reshape(-1, 16, n, d)
289
+ # v = v.reshape(-1, 16, n, d)
290
+
291
+ # # Cross-attention
292
+ # x = flash_attention(q, k, v, k_lens=None) # No masking for audio
293
+
294
+ # x = x.view(b, -1, n, d).flatten(2)
295
+ # x = self.o(x)
296
+ # return x
297
+
298
+
299
+ class AudioCrossAttentionWrapper(nn.Module):
300
+ def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6,):
301
+ super().__init__()
302
+
303
+ self.audio_cross_attn = WanT2VCrossAttentionGather(
304
+ kv_dim, dim, num_heads, (-1, -1), qk_norm, eps)
305
+ self.norm1_audio = WanLayerNorm(dim, eps,
306
+ elementwise_affine=True)
307
+
308
+ def forward(self, x, audio, seq_lens, grid_sizes, freqs, audio_seq_len):
309
+ x = x + self.audio_cross_attn(
310
+ self.norm1_audio(x), audio, seq_lens, grid_sizes, freqs, audio_seq_len)
311
+ return x
312
+
313
+
314
+ class WanI2VCrossAttention(WanSelfAttention):
315
+
316
+ def __init__(self,
317
+ dim,
318
+ num_heads,
319
+ window_size=(-1, -1),
320
+ qk_norm=True,
321
+ eps=1e-6):
322
+ super().__init__(dim, num_heads, window_size, qk_norm, eps)
323
+
324
+ def forward(self, x, context, context_lens):
325
+ r"""
326
+ Args:
327
+ x(Tensor): Shape [B, L1, C]
328
+ context(Tensor): Shape [B, L2, C]
329
+ context_lens(Tensor): Shape [B]
330
+ """
331
+ b, n, d = x.size(0), self.num_heads, self.head_dim
332
+
333
+ # compute query, key, value
334
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
335
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
336
+ v = self.v(context).view(b, -1, n, d)
337
+ x = flash_attention(q, k, v, k_lens=context_lens)
338
+
339
+ # output
340
+ x = x.flatten(2)
341
+ x = self.o(x)
342
+ return x
343
+
344
+
345
+ WAN_CROSSATTENTION_CLASSES = {
346
+ 't2v_cross_attn': WanT2VCrossAttention,
347
+ 'i2v_cross_attn': WanI2VCrossAttention,
348
+ }
349
+
350
+ class WanAttentionBlock(nn.Module):
351
+
352
+ def __init__(self,
353
+ cross_attn_type,
354
+ dim,
355
+ ffn_dim,
356
+ num_heads,
357
+ window_size=(-1, -1),
358
+ qk_norm=True,
359
+ cross_attn_norm=False,
360
+ eps=1e-6,
361
+ use_audio=True):
362
+ super().__init__()
363
+ self.dim = dim
364
+ self.ffn_dim = ffn_dim
365
+ self.num_heads = num_heads
366
+ self.window_size = window_size
367
+ self.qk_norm = qk_norm
368
+ self.cross_attn_norm = cross_attn_norm
369
+ self.eps = eps
370
+
371
+ # layers
372
+ self.norm1 = WanLayerNorm(dim, eps)
373
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
374
+ eps)
375
+ self.norm3 = WanLayerNorm(
376
+ dim, eps,
377
+ elementwise_affine=True) if cross_attn_norm else nn.Identity()
378
+ self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
379
+ num_heads,
380
+ (-1, -1),
381
+ qk_norm,
382
+ eps)
383
+ self.norm2 = WanLayerNorm(dim, eps)
384
+ self.ffn = nn.Sequential(
385
+ nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
386
+ nn.Linear(ffn_dim, dim))
387
+
388
+ # modulation
389
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
390
+
391
+ self.use_audio = use_audio
392
+ if use_audio:
393
+ self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(dim, 1536, num_heads, qk_norm, eps)
394
+
395
+ def forward(
396
+ self,
397
+ x, # torch.Size([1, 9360, 5120])
398
+ e, # torch.Size([1, 6, 5120])
399
+ seq_lens, # tensor([9360])
400
+ grid_sizes, # tensor([[ 6, 30, 52]])
401
+ freqs, # torch.Size([1024, 64])
402
+ context, # torch.Size([1, 512, 5120])
403
+ context_lens, # None
404
+ audio=None, # None
405
+ audio_seq_len=None,
406
+ ref_num_list=None,
407
+ ):
408
+ r"""
409
+ Args:
410
+ x(Tensor): Shape [B, L, C]
411
+ e(Tensor): Shape [B, 6, C], AdaLN time-modulation parameters
412
+ audio(Tensor): Shape [B, L_audio, C_audio], audio context tokens for cross-attention
413
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
414
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
415
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
416
+ ref_num_list: together with seq_lens, gives how far from the end of the sequence each reference image sits
417
+ """
418
+ assert e.dtype == torch.float32
419
+ with torch.amp.autocast('cuda', dtype=torch.float32):
420
+ e = (self.modulation + e).chunk(6, dim=1)
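+ # e[0..2]: shift/scale/gate applied around self-attention; e[3..5]: shift/scale/gate around the FFN (AdaLN-style modulation).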
421
+ assert e[0].dtype == torch.float32
422
+
423
+ # self-attention
424
+ y = self.self_attn(
425
+ self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
426
+ freqs)
427
+ with torch.amp.autocast('cuda', dtype=torch.float32):
428
+ x = x + y * e[2]
429
+
430
+ # cross-attention & ffn function
431
+ def cross_attn_ffn(x, context, context_lens, e):
432
+ x = x + self.cross_attn(self.norm3(x), context, context_lens)
433
+
434
+ if self.use_audio:
435
+ x = self.audio_cross_attn_wrapper(x, audio, seq_lens, grid_sizes, freqs, audio_seq_len)
436
+
437
+ y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
438
+ with torch.amp.autocast('cuda', dtype=torch.float32):
439
+ x = x + y * e[5]
440
+ return x
441
+
442
+ x = cross_attn_ffn(x, context, context_lens, e)
443
+
444
+ return x
445
+
446
+
447
+ class Head(nn.Module):
448
+
449
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
450
+ super().__init__()
451
+ self.dim = dim
452
+ self.out_dim = out_dim
453
+ self.patch_size = patch_size
454
+ self.eps = eps
455
+
456
+ # layers
457
+ out_dim = math.prod(patch_size) * out_dim
458
+ self.norm = WanLayerNorm(dim, eps)
459
+ self.head = nn.Linear(dim, out_dim)
460
+
461
+ # modulation
462
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
463
+
464
+ def forward(self, x, e):
465
+ r"""
466
+ Args:
467
+ x(Tensor): Shape [B, L1, C]
468
+ e(Tensor): Shape [B, C]
469
+ """
470
+ assert e.dtype == torch.float32
471
+ with torch.amp.autocast('cuda', dtype=torch.float32):
472
+ e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
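+ # e[0] is the shift and e[1] the scale for the final AdaLN-modulated projection.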
473
+ x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
474
+ return x
475
+
476
+
477
+ class MLPProj(torch.nn.Module):
478
+
479
+ def __init__(self, in_dim, out_dim):
480
+ super().__init__()
481
+
482
+ self.proj = torch.nn.Sequential(
483
+ torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
484
+ torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
485
+ torch.nn.LayerNorm(out_dim))
486
+
487
+ def forward(self, image_embeds):
488
+ clip_extra_context_tokens = self.proj(image_embeds)
489
+ return clip_extra_context_tokens
490
+
491
+
492
+ class WanModel(nn.Module):
493
+ r"""
494
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
495
+ """
496
+
497
+ ignore_for_config = [
498
+ 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
499
+ ]
500
+ _no_split_modules = ['WanAttentionBlock']
501
+
502
+ gradient_checkpointing = False
503
+
504
+ def __init__(self,
505
+ model_type='t2v',
506
+ patch_size=(1, 2, 2),
507
+ text_len=512,
508
+ in_dim=16,
509
+ dim=2048,
510
+ ffn_dim=13824,
511
+ freq_dim=256,
512
+ text_dim=4096,
513
+ out_dim=16,
514
+ num_heads=40,
515
+ num_layers=40,
516
+ window_size=(-1, -1),
517
+ qk_norm=True,
518
+ cross_attn_norm=True,
519
+ eps=1e-6,
520
+ audio_token_num=16,
521
+ insert_audio=True):
522
+ r"""
523
+ Initialize the diffusion model backbone.
524
+
525
+ Args:
526
+ model_type (`str`, *optional*, defaults to 't2v'):
527
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
528
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
529
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
530
+ text_len (`int`, *optional*, defaults to 512):
531
+ Fixed length for text embeddings
532
+ in_dim (`int`, *optional*, defaults to 16):
533
+ Input video channels (C_in)
534
+ dim (`int`, *optional*, defaults to 2048):
535
+ Hidden dimension of the transformer
536
+ ffn_dim (`int`, *optional*, defaults to 13824):
537
+ Intermediate dimension in feed-forward network
538
+ freq_dim (`int`, *optional*, defaults to 256):
539
+ Dimension for sinusoidal time embeddings
540
+ text_dim (`int`, *optional*, defaults to 4096):
541
+ Input dimension for text embeddings
542
+ out_dim (`int`, *optional*, defaults to 16):
543
+ Output video channels (C_out)
544
+ num_heads (`int`, *optional*, defaults to 40):
545
+ Number of attention heads
546
+ num_layers (`int`, *optional*, defaults to 40):
547
+ Number of transformer blocks
548
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
549
+ Window size for local attention (-1 indicates global attention)
550
+ qk_norm (`bool`, *optional*, defaults to True):
551
+ Enable query/key normalization
552
+ cross_attn_norm (`bool`, *optional*, defaults to True):
553
+ Enable cross-attention normalization
554
+ eps (`float`, *optional*, defaults to 1e-6):
555
+ Epsilon value for normalization layers
556
+ """
557
+
558
+ super().__init__()
559
+
560
+ assert model_type in ['t2v', 'i2v']
561
+ self.model_type = model_type
562
+
563
+ self.patch_size = patch_size
564
+ self.text_len = text_len
565
+ self.in_dim = in_dim
566
+ self.dim = dim
567
+ self.ffn_dim = ffn_dim
568
+ self.freq_dim = freq_dim
569
+ self.text_dim = text_dim
570
+ self.out_dim = out_dim
571
+ self.num_heads = num_heads
572
+ self.num_layers = num_layers
573
+ self.window_size = window_size
574
+ self.qk_norm = qk_norm
575
+ self.cross_attn_norm = cross_attn_norm
576
+ self.eps = eps
577
+
578
+ # embeddings
579
+ self.patch_embedding = nn.Conv3d(
580
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
581
+ self.text_embedding = nn.Sequential(
582
+ nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
583
+ nn.Linear(dim, dim))
584
+
585
+ self.time_embedding = nn.Sequential(
586
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
587
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
588
+
589
+ # blocks
590
+ cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
591
+ self.insert_audio = insert_audio
592
+ self.blocks = nn.ModuleList([
593
+ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
594
+ window_size, qk_norm, cross_attn_norm,
595
+ eps, use_audio=self.insert_audio)
596
+ for _ in range(num_layers)
597
+ ])
598
+
599
+ # head
600
+ self.head = Head(dim, out_dim, patch_size, eps)
601
+
602
+ if self.insert_audio:
603
+ self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280,
604
+ intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num)
605
+
606
+ # RoPE freqs: register as a buffer so it moves with .to() / DDP and is tracked by compile
607
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
608
+ d = dim // num_heads
609
+
610
+ _freqs = torch.cat([
611
+ rope_params(1024, d - 4 * (d // 6)),
612
+ rope_params(1024, 2 * (d // 6)),
613
+ rope_params(1024, 2 * (d // 6))
614
+ ], dim=1)
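+ # Per-head dim is split into three RoPE bands: d - 4*(d//6) dims for time, 2*(d//6) each for height and width.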
615
+ self.register_buffer("freqs", _freqs, persistent=False)
616
+
617
+ # initialize weights
618
+ self.init_weights()
619
+
620
+ # initialize unified parallel
621
+ if is_unified_parallel_initialized():
622
+ print(f"Initializing WanModel with unified parallel initialized")
623
+ from humo.models.distributed.dit_ulysses_sequence_parallel import ulysses_attn_forward, ulysses_dit_forward, ulysses_audio_cross_attn_forward
624
+ for block in self.blocks:
625
+ block.self_attn.forward = types.MethodType(ulysses_attn_forward, block.self_attn)
626
+ if block.use_audio:
627
+ block.audio_cross_attn_wrapper.audio_cross_attn.forward = types.MethodType(ulysses_audio_cross_attn_forward, block.audio_cross_attn_wrapper.audio_cross_attn)
628
+ self.forward = types.MethodType(ulysses_dit_forward, self)
629
+
630
+ def forward(
631
+ self,
632
+ x,
633
+ t,
634
+ context,
635
+ seq_len,
636
+ audio=None,
637
+ y=None,
638
+ tea_cache=None,
639
+ ):
640
+ r"""
641
+ Forward pass through the diffusion model
642
+
643
+ Args:
644
+ x (List[Tensor]):
645
+ List of input video tensors, each with shape [C_in, F, H, W]
646
+ t (Tensor):
647
+ Diffusion timesteps tensor of shape [B]
648
+ context (List[Tensor]):
649
+ List of text embeddings each with shape [L, C]
650
+ seq_len (`int`):
651
+ Maximum sequence length for positional encoding
652
+ audio (List[Tensor], *optional*):
654
+ Per-sample audio feature sequences consumed by the audio cross-attention blocks (when insert_audio is True)
654
+ y (List[Tensor], *optional*):
655
+ Conditional video inputs for image-to-video mode, same shape as x
656
+
657
+ Returns:
658
+ List[Tensor]:
659
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
660
+ """
661
+ if self.model_type == 'i2v':
662
+ assert y is not None
663
+
664
+ # params
665
+ freqs = self.freqs
666
+
667
+ if y is not None:
668
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
669
+
670
+ # embeddings
671
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
672
+ grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
673
+
674
+ x = [u.flatten(2).transpose(1, 2) for u in x]
675
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
676
+ assert seq_lens.max() <= seq_len
677
+
678
+ # pad to uniform length and batch
679
+ x = torch.cat([
680
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
681
+ for u in x
682
+ ]) # shape: [B, seq_len, C]
683
+
684
+ # time embeddings
685
+ with torch.amp.autocast('cuda', dtype=torch.float32):
686
+ e = self.time_embedding(
687
+ sinusoidal_embedding_1d(self.freq_dim, t).float()
688
+ ).float()
689
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim)).float()
690
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
691
+
692
+ # context
693
+ context_lens = None
694
+ context = self.text_embedding(
695
+ torch.stack([
696
+ torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
697
+ for u in context
698
+ ])
699
+ )
700
+
701
+ # audio (unchanged; not cached)
702
+ if self.insert_audio:
703
+ audio = [self.audio_proj(au.unsqueeze(0)).permute(0, 3, 1, 2) for au in audio]
704
+ audio_seq_len = max(au.shape[2] for au in audio) * audio[0].shape[3]
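+ # audio_seq_len = (max number of latent frames) * tokens-per-frame; with the default audio_token_num=16 this is F * 16 audio tokens.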
705
+
706
+ audio = [au.flatten(2).transpose(1, 2) for au in audio] # [1, t*audio_token_num, 1536]
707
+ audio = torch.cat([
708
+ torch.cat([au, au.new_zeros(1, int(audio_seq_len) - au.size(1), au.size(2))], dim=1)
709
+ for au in audio
710
+ ])
711
+ else:
712
+ audio = None
713
+ audio_seq_len = None
714
+
715
+ # ---- tea_cache integration (mirrors your working model) ----
716
+ if tea_cache is not None:
717
+ # Use the pre-block tokens 'x' and time-mod 'e0' to decide whether to reuse cache
718
+ tea_cache_update = tea_cache.check(self, x, e0)
719
+ else:
720
+ tea_cache_update = False
721
+
722
+ ori_x_len = x.shape[1] # remember original token length before potential cache extension
723
+
724
+ if tea_cache_update:
725
+ # Let the cache inject/append any needed past states/tokens for reuse
726
+ x = tea_cache.update(x)
727
+ else:
728
+ # arguments for blocks
729
+ kwargs = dict(
730
+ e=e0,
731
+ seq_lens=seq_lens,
732
+ grid_sizes=grid_sizes,
733
+ freqs=freqs,
734
+ context=context,
735
+ context_lens=context_lens,
736
+ audio=audio,
737
+ audio_seq_len=audio_seq_len
738
+ )
739
+
740
+ # transformer blocks
741
+ for block in self.blocks:
742
+ x = block(x, **kwargs)
743
+
744
+ if tea_cache is not None:
745
+ x_cache = x[:, :ori_x_len]
746
+ tea_cache.store(x_cache)
747
+
748
+ # head
749
+ x = self.head(x, e)
750
+
751
+ # unpatchify
752
+ x = self.unpatchify(x, grid_sizes)
753
+ return [u.float() for u in x]
754
+
755
+
756
+ def unpatchify(self, x, grid_sizes):
757
+ r"""
758
+ Reconstruct video tensors from patch embeddings.
759
+
760
+ Args:
761
+ x (List[Tensor]):
762
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
763
+ grid_sizes (Tensor):
764
+ Original spatial-temporal grid dimensions before patching,
765
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
766
+
767
+ Returns:
768
+ List[Tensor]:
769
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
770
+ """
771
+
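+ # Illustrative shape walk-through (toy sizes assumed): with out_dim=16, patch_size=(1, 2, 2)
+ # and grid_sizes=[6, 30, 52], u[:9360] is viewed as [6, 30, 52, 1, 2, 2, 16], permuted to
+ # [16, 6, 1, 30, 2, 52, 2], and reshaped to the latent video [16, 6, 60, 104].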
772
+ c = self.out_dim
773
+ out = []
774
+ for u, v in zip(x, grid_sizes.tolist()):
775
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
776
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
777
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
778
+ out.append(u)
779
+ return out
780
+
781
+ def init_weights(self):
782
+ r"""
783
+ Initialize model parameters using Xavier initialization.
784
+ """
785
+
786
+ # basic init
787
+ for m in self.modules():
788
+ if isinstance(m, nn.Linear):
789
+ nn.init.xavier_uniform_(m.weight)
790
+ if m.bias is not None:
791
+ nn.init.zeros_(m.bias)
792
+
793
+ # init embeddings
794
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
795
+ for m in self.text_embedding.modules():
796
+ if isinstance(m, nn.Linear):
797
+ nn.init.normal_(m.weight, std=.02)
798
+ for m in self.time_embedding.modules():
799
+ if isinstance(m, nn.Linear):
800
+ nn.init.normal_(m.weight, std=.02)
801
+
802
+ # init output layer
803
+ nn.init.zeros_(self.head.head.weight)
humo/models/wan_modules/t5.py ADDED
@@ -0,0 +1,525 @@
1
+ # Modified from transformers.models.t5.modeling_t5
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .tokenizers import HuggingfaceTokenizer
11
+
12
+ __all__ = [
13
+ 'T5Model',
14
+ 'T5Encoder',
15
+ 'T5Decoder',
16
+ 'T5EncoderModel',
17
+ ]
18
+
19
+
20
+ def fp16_clamp(x):
21
+ if x.dtype == torch.float16 and torch.isinf(x).any():
22
+ clamp = torch.finfo(x.dtype).max - 1000
23
+ x = torch.clamp(x, min=-clamp, max=clamp)
24
+ return x
25
+
26
+
27
+ def init_weights(m):
28
+ if isinstance(m, T5LayerNorm):
29
+ nn.init.ones_(m.weight)
30
+ elif isinstance(m, T5Model):
31
+ nn.init.normal_(m.token_embedding.weight, std=1.0)
32
+ elif isinstance(m, T5FeedForward):
33
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
34
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
35
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
36
+ elif isinstance(m, T5Attention):
37
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
38
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
39
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
40
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
41
+ elif isinstance(m, T5RelativeEmbedding):
42
+ nn.init.normal_(
43
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
44
+
45
+
46
+ class GELU(nn.Module):
47
+
48
+ def forward(self, x):
49
+ return 0.5 * x * (1.0 + torch.tanh(
50
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
51
+
52
+
53
+ class T5LayerNorm(nn.Module):
54
+
55
+ def __init__(self, dim, eps=1e-6):
56
+ super(T5LayerNorm, self).__init__()
57
+ self.dim = dim
58
+ self.eps = eps
59
+ self.weight = nn.Parameter(torch.ones(dim))
60
+
61
+ def forward(self, x):
62
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
63
+ self.eps)
64
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
65
+ x = x.type_as(self.weight)
66
+ return self.weight * x
67
+
68
+
69
+ class T5Attention(nn.Module):
70
+
71
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
72
+ assert dim_attn % num_heads == 0
73
+ super(T5Attention, self).__init__()
74
+ self.dim = dim
75
+ self.dim_attn = dim_attn
76
+ self.num_heads = num_heads
77
+ self.head_dim = dim_attn // num_heads
78
+
79
+ # layers
80
+ self.q = nn.Linear(dim, dim_attn, bias=False)
81
+ self.k = nn.Linear(dim, dim_attn, bias=False)
82
+ self.v = nn.Linear(dim, dim_attn, bias=False)
83
+ self.o = nn.Linear(dim_attn, dim, bias=False)
84
+ self.dropout = nn.Dropout(dropout)
85
+
86
+ def forward(self, x, context=None, mask=None, pos_bias=None):
87
+ """
88
+ x: [B, L1, C].
89
+ context: [B, L2, C] or None.
90
+ mask: [B, L2] or [B, L1, L2] or None.
91
+ """
92
+ # check inputs
93
+ context = x if context is None else context
94
+ b, n, c = x.size(0), self.num_heads, self.head_dim
95
+
96
+ # compute query, key, value
97
+ q = self.q(x).view(b, -1, n, c)
98
+ k = self.k(context).view(b, -1, n, c)
99
+ v = self.v(context).view(b, -1, n, c)
100
+
101
+ # attention bias
102
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
103
+ if pos_bias is not None:
104
+ attn_bias += pos_bias
105
+ if mask is not None:
106
+ assert mask.ndim in [2, 3]
107
+ mask = mask.view(b, 1, 1,
108
+ -1) if mask.ndim == 2 else mask.unsqueeze(1)
109
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
110
+
111
+ # compute attention (T5 does not use scaling)
112
+ attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
113
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
114
+ x = torch.einsum('bnij,bjnc->binc', attn, v)
115
+
116
+ # output
117
+ x = x.reshape(b, -1, n * c)
118
+ x = self.o(x)
119
+ x = self.dropout(x)
120
+ return x
121
+
122
+
123
+ class T5FeedForward(nn.Module):
124
+
125
+ def __init__(self, dim, dim_ffn, dropout=0.1):
126
+ super(T5FeedForward, self).__init__()
127
+ self.dim = dim
128
+ self.dim_ffn = dim_ffn
129
+
130
+ # layers
131
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
132
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
133
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
134
+ self.dropout = nn.Dropout(dropout)
135
+
136
+ def forward(self, x):
137
+ x = self.fc1(x) * self.gate(x)
138
+ x = self.dropout(x)
139
+ x = self.fc2(x)
140
+ x = self.dropout(x)
141
+ return x
142
+
143
+
144
+ class T5SelfAttention(nn.Module):
145
+
146
+ def __init__(self,
147
+ dim,
148
+ dim_attn,
149
+ dim_ffn,
150
+ num_heads,
151
+ num_buckets,
152
+ shared_pos=True,
153
+ dropout=0.1):
154
+ super(T5SelfAttention, self).__init__()
155
+ self.dim = dim
156
+ self.dim_attn = dim_attn
157
+ self.dim_ffn = dim_ffn
158
+ self.num_heads = num_heads
159
+ self.num_buckets = num_buckets
160
+ self.shared_pos = shared_pos
161
+
162
+ # layers
163
+ self.norm1 = T5LayerNorm(dim)
164
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
165
+ self.norm2 = T5LayerNorm(dim)
166
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
167
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
168
+ num_buckets, num_heads, bidirectional=True)
169
+
170
+ def forward(self, x, mask=None, pos_bias=None):
171
+ e = pos_bias if self.shared_pos else self.pos_embedding(
172
+ x.size(1), x.size(1))
173
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
174
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
175
+ return x
176
+
177
+
178
+ class T5CrossAttention(nn.Module):
179
+
180
+ def __init__(self,
181
+ dim,
182
+ dim_attn,
183
+ dim_ffn,
184
+ num_heads,
185
+ num_buckets,
186
+ shared_pos=True,
187
+ dropout=0.1):
188
+ super(T5CrossAttention, self).__init__()
189
+ self.dim = dim
190
+ self.dim_attn = dim_attn
191
+ self.dim_ffn = dim_ffn
192
+ self.num_heads = num_heads
193
+ self.num_buckets = num_buckets
194
+ self.shared_pos = shared_pos
195
+
196
+ # layers
197
+ self.norm1 = T5LayerNorm(dim)
198
+ self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
199
+ self.norm2 = T5LayerNorm(dim)
200
+ self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
201
+ self.norm3 = T5LayerNorm(dim)
202
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
203
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
204
+ num_buckets, num_heads, bidirectional=False)
205
+
206
+ def forward(self,
207
+ x,
208
+ mask=None,
209
+ encoder_states=None,
210
+ encoder_mask=None,
211
+ pos_bias=None):
212
+ e = pos_bias if self.shared_pos else self.pos_embedding(
213
+ x.size(1), x.size(1))
214
+ x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
215
+ x = fp16_clamp(x + self.cross_attn(
216
+ self.norm2(x), context=encoder_states, mask=encoder_mask))
217
+ x = fp16_clamp(x + self.ffn(self.norm3(x)))
218
+ return x
219
+
220
+
221
+ class T5RelativeEmbedding(nn.Module):
222
+
223
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
224
+ super(T5RelativeEmbedding, self).__init__()
225
+ self.num_buckets = num_buckets
226
+ self.num_heads = num_heads
227
+ self.bidirectional = bidirectional
228
+ self.max_dist = max_dist
229
+
230
+ # layers
231
+ self.embedding = nn.Embedding(num_buckets, num_heads)
232
+
233
+ def forward(self, lq, lk):
234
+ device = self.embedding.weight.device
235
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
236
+ # torch.arange(lq).unsqueeze(1).to(device)
237
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
238
+ torch.arange(lq, device=device).unsqueeze(1)
239
+ rel_pos = self._relative_position_bucket(rel_pos)
240
+ rel_pos_embeds = self.embedding(rel_pos)
241
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
242
+ 0) # [1, N, Lq, Lk]
243
+ return rel_pos_embeds.contiguous()
244
+
245
+ def _relative_position_bucket(self, rel_pos):
246
+ # preprocess
247
+ if self.bidirectional:
248
+ num_buckets = self.num_buckets // 2
249
+ rel_buckets = (rel_pos > 0).long() * num_buckets
250
+ rel_pos = torch.abs(rel_pos)
251
+ else:
252
+ num_buckets = self.num_buckets
253
+ rel_buckets = 0
254
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
255
+
256
+ # embeddings for small and large positions
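+ # With the defaults used by the encoder here (num_buckets=32, max_dist=128) and bidirectional=True,
+ # each direction gets 16 buckets: distances 0..7 map to exact buckets, larger distances share log-spaced buckets up to max_dist.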
257
+ max_exact = num_buckets // 2
258
+ rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
259
+ math.log(self.max_dist / max_exact) *
260
+ (num_buckets - max_exact)).long()
261
+ rel_pos_large = torch.min(
262
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
263
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
264
+ return rel_buckets
265
+
266
+
267
+ class T5Encoder(nn.Module):
268
+
269
+ def __init__(self,
270
+ vocab,
271
+ dim,
272
+ dim_attn,
273
+ dim_ffn,
274
+ num_heads,
275
+ num_layers,
276
+ num_buckets,
277
+ shared_pos=True,
278
+ dropout=0.1):
279
+ super(T5Encoder, self).__init__()
280
+ self.dim = dim
281
+ self.dim_attn = dim_attn
282
+ self.dim_ffn = dim_ffn
283
+ self.num_heads = num_heads
284
+ self.num_layers = num_layers
285
+ self.num_buckets = num_buckets
286
+ self.shared_pos = shared_pos
287
+
288
+ # layers
289
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
290
+ else nn.Embedding(vocab, dim)
291
+ self.pos_embedding = T5RelativeEmbedding(
292
+ num_buckets, num_heads, bidirectional=True) if shared_pos else None
293
+ self.dropout = nn.Dropout(dropout)
294
+ self.blocks = nn.ModuleList([
295
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
296
+ shared_pos, dropout) for _ in range(num_layers)
297
+ ])
298
+ self.norm = T5LayerNorm(dim)
299
+
300
+ # initialize weights
301
+ self.apply(init_weights)
302
+
303
+ def forward(self, ids, mask=None):
304
+ x = self.token_embedding(ids)
305
+ x = self.dropout(x)
306
+ e = self.pos_embedding(x.size(1),
307
+ x.size(1)) if self.shared_pos else None
308
+ for block in self.blocks:
309
+ x = block(x, mask, pos_bias=e)
310
+ x = self.norm(x)
311
+ x = self.dropout(x)
312
+ return x
313
+
314
+
315
+ class T5Decoder(nn.Module):
316
+
317
+ def __init__(self,
318
+ vocab,
319
+ dim,
320
+ dim_attn,
321
+ dim_ffn,
322
+ num_heads,
323
+ num_layers,
324
+ num_buckets,
325
+ shared_pos=True,
326
+ dropout=0.1):
327
+ super(T5Decoder, self).__init__()
328
+ self.dim = dim
329
+ self.dim_attn = dim_attn
330
+ self.dim_ffn = dim_ffn
331
+ self.num_heads = num_heads
332
+ self.num_layers = num_layers
333
+ self.num_buckets = num_buckets
334
+ self.shared_pos = shared_pos
335
+
336
+ # layers
337
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
338
+ else nn.Embedding(vocab, dim)
339
+ self.pos_embedding = T5RelativeEmbedding(
340
+ num_buckets, num_heads, bidirectional=False) if shared_pos else None
341
+ self.dropout = nn.Dropout(dropout)
342
+ self.blocks = nn.ModuleList([
343
+ T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
344
+ shared_pos, dropout) for _ in range(num_layers)
345
+ ])
346
+ self.norm = T5LayerNorm(dim)
347
+
348
+ # initialize weights
349
+ self.apply(init_weights)
350
+
351
+ def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
352
+ b, s = ids.size()
353
+
354
+ # causal mask
355
+ if mask is None:
356
+ mask = torch.tril(torch.ones(1, s, s).to(ids.device))
357
+ elif mask.ndim == 2:
358
+ mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
359
+
360
+ # layers
361
+ x = self.token_embedding(ids)
362
+ x = self.dropout(x)
363
+ e = self.pos_embedding(x.size(1),
364
+ x.size(1)) if self.shared_pos else None
365
+ for block in self.blocks:
366
+ x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
367
+ x = self.norm(x)
368
+ x = self.dropout(x)
369
+ return x
370
+
371
+
372
+ class T5Model(nn.Module):
373
+
374
+ def __init__(self,
375
+ vocab_size,
376
+ dim,
377
+ dim_attn,
378
+ dim_ffn,
379
+ num_heads,
380
+ encoder_layers,
381
+ decoder_layers,
382
+ num_buckets,
383
+ shared_pos=True,
384
+ dropout=0.1):
385
+ super(T5Model, self).__init__()
386
+ self.vocab_size = vocab_size
387
+ self.dim = dim
388
+ self.dim_attn = dim_attn
389
+ self.dim_ffn = dim_ffn
390
+ self.num_heads = num_heads
391
+ self.encoder_layers = encoder_layers
392
+ self.decoder_layers = decoder_layers
393
+ self.num_buckets = num_buckets
394
+
395
+ # layers
396
+ self.token_embedding = nn.Embedding(vocab_size, dim)
397
+ self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
398
+ num_heads, encoder_layers, num_buckets,
399
+ shared_pos, dropout)
400
+ self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
401
+ num_heads, decoder_layers, num_buckets,
402
+ shared_pos, dropout)
403
+ self.head = nn.Linear(dim, vocab_size, bias=False)
404
+
405
+ # initialize weights
406
+ self.apply(init_weights)
407
+
408
+ def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
409
+ x = self.encoder(encoder_ids, encoder_mask)
410
+ x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
411
+ x = self.head(x)
412
+ return x
413
+
414
+
415
+ def _t5(name,
416
+ encoder_only=False,
417
+ decoder_only=False,
418
+ return_tokenizer=False,
419
+ tokenizer_kwargs={},
420
+ dtype=torch.float32,
421
+ device='cpu',
422
+ **kwargs):
423
+ # sanity check
424
+ assert not (encoder_only and decoder_only)
425
+
426
+ # params
427
+ if encoder_only:
428
+ model_cls = T5Encoder
429
+ kwargs['vocab'] = kwargs.pop('vocab_size')
430
+ kwargs['num_layers'] = kwargs.pop('encoder_layers')
431
+ _ = kwargs.pop('decoder_layers')
432
+ elif decoder_only:
433
+ model_cls = T5Decoder
434
+ kwargs['vocab'] = kwargs.pop('vocab_size')
435
+ kwargs['num_layers'] = kwargs.pop('decoder_layers')
436
+ _ = kwargs.pop('encoder_layers')
437
+ else:
438
+ model_cls = T5Model
439
+
440
+ # init model
441
+ with torch.device(device):
442
+ model = model_cls(**kwargs)
443
+
444
+ # set device
445
+ model = model.to(dtype=dtype, device=device)
446
+
447
+ # init tokenizer
448
+ if return_tokenizer:
449
+ from .tokenizers import HuggingfaceTokenizer
450
+ tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
451
+ return model, tokenizer
452
+ else:
453
+ return model
454
+
455
+
456
+ def umt5_xxl(**kwargs):
457
+ cfg = dict(
458
+ vocab_size=256384,
459
+ dim=4096,
460
+ dim_attn=4096,
461
+ dim_ffn=10240,
462
+ num_heads=64,
463
+ encoder_layers=24,
464
+ decoder_layers=24,
465
+ num_buckets=32,
466
+ shared_pos=False,
467
+ dropout=0.1)
468
+ cfg.update(**kwargs)
469
+ return _t5('umt5-xxl', **cfg)
470
+
471
+
472
+ class T5EncoderModel(nn.Module):
473
+
474
+ def __init__(
475
+ self,
476
+ text_len,
477
+ dtype=torch.bfloat16,
478
+ device=torch.cuda.current_device(),
479
+ checkpoint_path=None,
480
+ tokenizer_path=None,
481
+ shard_fn=None,
482
+ ):
483
+ super(T5EncoderModel, self).__init__()
484
+ self.text_len = text_len
485
+ self.dtype = dtype
486
+ self.device = device
487
+ self.checkpoint_path = checkpoint_path
488
+ self.tokenizer_path = tokenizer_path
489
+
490
+ with torch.device(device):
491
+ self.model = T5Encoder(
492
+ vocab=256384,
493
+ dim=4096,
494
+ dim_attn=4096,
495
+ dim_ffn=10240,
496
+ num_heads=64,
497
+ num_layers=24,
498
+ num_buckets=32,
499
+ shared_pos=False,
500
+ dropout=0.1
501
+ )
502
+ # set device
503
+ self.model = self.model.to(dtype=dtype, device=device).eval().requires_grad_(False)
504
+
505
+ logging.info(f'loading {checkpoint_path}')
506
+ if checkpoint_path is not None:
507
+ self.model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
508
+
509
+ if shard_fn is not None:
510
+ self.model = shard_fn(self.model, sync_module_states=False)
511
+ else:
512
+ self.model.to(self.device)
513
+ # init tokenizer
514
+ self.tokenizer = HuggingfaceTokenizer(
515
+ name=tokenizer_path, seq_len=text_len, clean='whitespace')
516
+
517
+ @torch.no_grad()
518
+ def __call__(self, texts, device):
519
+ ids, mask = self.tokenizer(
520
+ texts, return_mask=True, add_special_tokens=True)
521
+ ids = ids.to(device)
522
+ mask = mask.to(device)
523
+ seq_lens = mask.gt(0).sum(dim=1).long()
524
+ context = self.model(ids, mask)
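+ # Trim the padded output back to each prompt's true token count; every returned tensor is [n_tokens_i, 4096].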
525
+ return [u[:v] for u, v in zip(context, seq_lens)]
humo/models/wan_modules/tokenizers.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import html
3
+ import string
4
+
5
+ import ftfy
6
+ import regex as re
7
+ from transformers import AutoTokenizer
8
+
9
+ __all__ = ['HuggingfaceTokenizer']
10
+
11
+
12
+ def basic_clean(text):
13
+ text = ftfy.fix_text(text)
14
+ text = html.unescape(html.unescape(text))
15
+ return text.strip()
16
+
17
+
18
+ def whitespace_clean(text):
19
+ text = re.sub(r'\s+', ' ', text)
20
+ text = text.strip()
21
+ return text
22
+
23
+
24
+ def canonicalize(text, keep_punctuation_exact_string=None):
25
+ text = text.replace('_', ' ')
26
+ if keep_punctuation_exact_string:
27
+ text = keep_punctuation_exact_string.join(
28
+ part.translate(str.maketrans('', '', string.punctuation))
29
+ for part in text.split(keep_punctuation_exact_string))
30
+ else:
31
+ text = text.translate(str.maketrans('', '', string.punctuation))
32
+ text = text.lower()
33
+ text = re.sub(r'\s+', ' ', text)
34
+ return text.strip()
35
+
36
+
37
+ class HuggingfaceTokenizer:
38
+
39
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
40
+ assert clean in (None, 'whitespace', 'lower', 'canonicalize')
41
+ self.name = name
42
+ self.seq_len = seq_len
43
+ self.clean = clean
44
+
45
+ # init tokenizer
46
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
47
+ self.vocab_size = self.tokenizer.vocab_size
48
+
49
+ def __call__(self, sequence, **kwargs):
50
+ return_mask = kwargs.pop('return_mask', False)
51
+
52
+ # arguments
53
+ _kwargs = {'return_tensors': 'pt'}
54
+ if self.seq_len is not None:
55
+ _kwargs.update({
56
+ 'padding': 'max_length',
57
+ 'truncation': True,
58
+ 'max_length': self.seq_len
59
+ })
60
+ _kwargs.update(**kwargs)
61
+
62
+ # tokenization
63
+ if isinstance(sequence, str):
64
+ sequence = [sequence]
65
+ if self.clean:
66
+ sequence = [self._clean(u) for u in sequence]
67
+ ids = self.tokenizer(sequence, **_kwargs)
68
+
69
+ # output
70
+ if return_mask:
71
+ return ids.input_ids, ids.attention_mask
72
+ else:
73
+ return ids.input_ids
74
+
75
+ def _clean(self, text):
76
+ if self.clean == 'whitespace':
77
+ text = whitespace_clean(basic_clean(text))
78
+ elif self.clean == 'lower':
79
+ text = whitespace_clean(basic_clean(text)).lower()
80
+ elif self.clean == 'canonicalize':
81
+ text = canonicalize(basic_clean(text))
82
+ return text
humo/models/wan_modules/vae.py ADDED
@@ -0,0 +1,666 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import logging
3
+
4
+ import torch
5
+ import torch.cuda.amp as amp
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+
10
+ __all__ = [
11
+ 'WanVAE',
12
+ ]
13
+
14
+ CACHE_T = 2
15
+
16
+
17
+ class CausalConv3d(nn.Conv3d):
18
+ """
19
+ Causal 3D convolution.
20
+ """
21
+
22
+ def __init__(self, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ self._padding = (self.padding[2], self.padding[2], self.padding[1],
25
+ self.padding[1], 2 * self.padding[0], 0)
26
+ self.padding = (0, 0, 0)
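+ # F.pad order is (W_left, W_right, H_top, H_bottom, T_front, T_back): all temporal padding goes on the front (past) side, which is what makes the convolution causal.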
27
+
28
+ def forward(self, x, cache_x=None):
29
+ padding = list(self._padding)
30
+ if cache_x is not None and self._padding[4] > 0:
31
+ cache_x = cache_x.to(x.device)
32
+ x = torch.cat([cache_x, x], dim=2)
33
+ padding[4] -= cache_x.shape[2]
34
+ x = F.pad(x, padding)
35
+
36
+ return super().forward(x)
37
+
38
+
39
+ class RMS_norm(nn.Module):
40
+
41
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
42
+ super().__init__()
43
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
44
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
45
+
46
+ self.channel_first = channel_first
47
+ self.scale = dim**0.5
48
+ self.gamma = nn.Parameter(torch.ones(shape))
49
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
50
+
51
+ def forward(self, x):
52
+ return F.normalize(
53
+ x, dim=(1 if self.channel_first else
54
+ -1)) * self.scale * self.gamma + self.bias
55
+
56
+
57
+ class Upsample(nn.Upsample):
58
+
59
+ def forward(self, x):
60
+ """
61
+ Fix bfloat16 support for nearest neighbor interpolation.
62
+ """
63
+ return super().forward(x.float()).type_as(x)
64
+
65
+
66
+ class Resample(nn.Module):
67
+
68
+ def __init__(self, dim, mode):
69
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
70
+ 'downsample3d')
71
+ super().__init__()
72
+ self.dim = dim
73
+ self.mode = mode
74
+
75
+ # layers
76
+ if mode == 'upsample2d':
77
+ self.resample = nn.Sequential(
78
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
79
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
80
+ elif mode == 'upsample3d':
81
+ self.resample = nn.Sequential(
82
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
83
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
84
+ self.time_conv = CausalConv3d(
85
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
86
+
87
+ elif mode == 'downsample2d':
88
+ self.resample = nn.Sequential(
89
+ nn.ZeroPad2d((0, 1, 0, 1)),
90
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
91
+ elif mode == 'downsample3d':
92
+ self.resample = nn.Sequential(
93
+ nn.ZeroPad2d((0, 1, 0, 1)),
94
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
95
+ self.time_conv = CausalConv3d(
96
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
97
+
98
+ else:
99
+ self.resample = nn.Identity()
100
+
101
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
102
+ b, c, t, h, w = x.size()
103
+ if self.mode == 'upsample3d':
104
+ if feat_cache is not None:
105
+ idx = feat_idx[0]
106
+ if feat_cache[idx] is None:
107
+ feat_cache[idx] = 'Rep'
108
+ feat_idx[0] += 1
109
+ else:
110
+
111
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
112
+ if cache_x.shape[2] < 2 and feat_cache[
113
+ idx] is not None and feat_cache[idx] != 'Rep':
114
+ # cache last frame of last two chunk
115
+ cache_x = torch.cat([
116
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
117
+ cache_x.device), cache_x
118
+ ],
119
+ dim=2)
120
+ if cache_x.shape[2] < 2 and feat_cache[
121
+ idx] is not None and feat_cache[idx] == 'Rep':
122
+ cache_x = torch.cat([
123
+ torch.zeros_like(cache_x).to(cache_x.device),
124
+ cache_x
125
+ ],
126
+ dim=2)
127
+ if feat_cache[idx] == 'Rep':
128
+ x = self.time_conv(x)
129
+ else:
130
+ x = self.time_conv(x, feat_cache[idx])
131
+ feat_cache[idx] = cache_x
132
+ feat_idx[0] += 1
133
+
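+ # time_conv doubled the channel dim; interleaving its two halves along the time axis below gives 2x temporal upsampling.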
134
+ x = x.reshape(b, 2, c, t, h, w)
135
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
136
+ 3)
137
+ x = x.reshape(b, c, t * 2, h, w)
138
+ t = x.shape[2]
139
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
140
+ x = self.resample(x)
141
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
142
+
143
+ if self.mode == 'downsample3d':
144
+ if feat_cache is not None:
145
+ idx = feat_idx[0]
146
+ if feat_cache[idx] is None:
147
+ feat_cache[idx] = x.clone()
148
+ feat_idx[0] += 1
149
+ else:
150
+
151
+ cache_x = x[:, :, -1:, :, :].clone()
152
+ # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
153
+ # # cache last frame of last two chunk
154
+ # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
155
+
156
+ x = self.time_conv(
157
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
158
+ feat_cache[idx] = cache_x
159
+ feat_idx[0] += 1
160
+ return x
161
+
162
+ def init_weight(self, conv):
163
+ conv_weight = conv.weight
164
+ nn.init.zeros_(conv_weight)
165
+ c1, c2, t, h, w = conv_weight.size()
166
+ one_matrix = torch.eye(c1, c2)
167
+ init_matrix = one_matrix
168
+ nn.init.zeros_(conv_weight)
169
+ #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
170
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
171
+ conv.weight.data.copy_(conv_weight)
172
+ nn.init.zeros_(conv.bias.data)
173
+
174
+ def init_weight2(self, conv):
175
+ conv_weight = conv.weight.data
176
+ nn.init.zeros_(conv_weight)
177
+ c1, c2, t, h, w = conv_weight.size()
178
+ init_matrix = torch.eye(c1 // 2, c2)
179
+ #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
180
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
181
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
182
+ conv.weight.data.copy_(conv_weight)
183
+ nn.init.zeros_(conv.bias.data)
184
+
185
+
186
+ class ResidualBlock(nn.Module):
187
+
188
+ def __init__(self, in_dim, out_dim, dropout=0.0):
189
+ super().__init__()
190
+ self.in_dim = in_dim
191
+ self.out_dim = out_dim
192
+
193
+ # layers
194
+ self.residual = nn.Sequential(
195
+ RMS_norm(in_dim, images=False), nn.SiLU(),
196
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
197
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
198
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
199
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
200
+ if in_dim != out_dim else nn.Identity()
201
+
202
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
203
+ h = self.shortcut(x)
204
+ for layer in self.residual:
205
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
206
+ idx = feat_idx[0]
207
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
208
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
209
+ # cache last frame of last two chunk
210
+ cache_x = torch.cat([
211
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
212
+ cache_x.device), cache_x
213
+ ],
214
+ dim=2)
215
+ x = layer(x, feat_cache[idx])
216
+ feat_cache[idx] = cache_x
217
+ feat_idx[0] += 1
218
+ else:
219
+ x = layer(x)
220
+ return x + h
221
+
222
+
223
+ class AttentionBlock(nn.Module):
224
+ """
225
+ Causal self-attention with a single head.
226
+ """
227
+
228
+ def __init__(self, dim):
229
+ super().__init__()
230
+ self.dim = dim
231
+
232
+ # layers
233
+ self.norm = RMS_norm(dim)
234
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
235
+ self.proj = nn.Conv2d(dim, dim, 1)
236
+
237
+ # zero out the last layer params
238
+ nn.init.zeros_(self.proj.weight)
239
+
240
+ def forward(self, x):
241
+ identity = x
242
+ b, c, t, h, w = x.size()
243
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
244
+ x = self.norm(x)
245
+ # compute query, key, value
246
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
247
+ -1).permute(0, 1, 3,
248
+ 2).contiguous().chunk(
249
+ 3, dim=-1)
250
+
251
+ # apply attention
252
+ x = F.scaled_dot_product_attention(
253
+ q,
254
+ k,
255
+ v,
256
+ )
257
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
258
+
259
+ # output
260
+ x = self.proj(x)
261
+ x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
262
+ return x + identity
263
+
264
+
265
+ class Encoder3d(nn.Module):
266
+
267
+ def __init__(self,
268
+ dim=128,
269
+ z_dim=4,
270
+ dim_mult=[1, 2, 4, 4],
271
+ num_res_blocks=2,
272
+ attn_scales=[],
273
+ temperal_downsample=[True, True, False],
274
+ dropout=0.0):
275
+ super().__init__()
276
+ self.dim = dim
277
+ self.z_dim = z_dim
278
+ self.dim_mult = dim_mult
279
+ self.num_res_blocks = num_res_blocks
280
+ self.attn_scales = attn_scales
281
+ self.temperal_downsample = temperal_downsample
282
+
283
+ # dimensions
284
+ dims = [dim * u for u in [1] + dim_mult]
285
+ scale = 1.0
286
+
287
+ # init block
288
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
289
+
290
+ # downsample blocks
291
+ downsamples = []
292
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
293
+ # residual (+attention) blocks
294
+ for _ in range(num_res_blocks):
295
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
296
+ if scale in attn_scales:
297
+ downsamples.append(AttentionBlock(out_dim))
298
+ in_dim = out_dim
299
+
300
+ # downsample block
301
+ if i != len(dim_mult) - 1:
302
+ mode = 'downsample3d' if temperal_downsample[
303
+ i] else 'downsample2d'
304
+ downsamples.append(Resample(out_dim, mode=mode))
305
+ scale /= 2.0
306
+ self.downsamples = nn.Sequential(*downsamples)
307
+
308
+ # middle blocks
309
+ self.middle = nn.Sequential(
310
+ ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
311
+ ResidualBlock(out_dim, out_dim, dropout))
312
+
313
+ # output blocks
314
+ self.head = nn.Sequential(
315
+ RMS_norm(out_dim, images=False), nn.SiLU(),
316
+ CausalConv3d(out_dim, z_dim, 3, padding=1))
317
+
318
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
319
+ if feat_cache is not None:
320
+ idx = feat_idx[0]
321
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
322
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
323
+ # cache last frame of last two chunk
324
+ cache_x = torch.cat([
325
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
326
+ cache_x.device), cache_x
327
+ ],
328
+ dim=2)
329
+ x = self.conv1(x, feat_cache[idx])
330
+ feat_cache[idx] = cache_x
331
+ feat_idx[0] += 1
332
+ else:
333
+ x = self.conv1(x)
334
+
335
+ ## downsamples
336
+ for layer in self.downsamples:
337
+ if feat_cache is not None:
338
+ x = layer(x, feat_cache, feat_idx)
339
+ else:
340
+ x = layer(x)
341
+
342
+ ## middle
343
+ for layer in self.middle:
344
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
345
+ x = layer(x, feat_cache, feat_idx)
346
+ else:
347
+ x = layer(x)
348
+
349
+ ## head
350
+ for layer in self.head:
351
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
352
+ idx = feat_idx[0]
353
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
354
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
355
+ # cache last frame of last two chunk
356
+ cache_x = torch.cat([
357
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
358
+ cache_x.device), cache_x
359
+ ],
360
+ dim=2)
361
+ x = layer(x, feat_cache[idx])
362
+ feat_cache[idx] = cache_x
363
+ feat_idx[0] += 1
364
+ else:
365
+ x = layer(x)
366
+ return x
367
+
368
+
369
+ class Decoder3d(nn.Module):
370
+
371
+ def __init__(self,
372
+ dim=128,
373
+ z_dim=4,
374
+ dim_mult=[1, 2, 4, 4],
375
+ num_res_blocks=2,
376
+ attn_scales=[],
377
+ temperal_upsample=[False, True, True],
378
+ dropout=0.0):
379
+ super().__init__()
380
+ self.dim = dim
381
+ self.z_dim = z_dim
382
+ self.dim_mult = dim_mult
383
+ self.num_res_blocks = num_res_blocks
384
+ self.attn_scales = attn_scales
385
+ self.temperal_upsample = temperal_upsample
386
+
387
+ # dimensions
388
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
389
+ scale = 1.0 / 2**(len(dim_mult) - 2)
390
+
391
+ # init block
392
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
393
+
394
+ # middle blocks
395
+ self.middle = nn.Sequential(
396
+ ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
397
+ ResidualBlock(dims[0], dims[0], dropout))
398
+
399
+ # upsample blocks
400
+ upsamples = []
401
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
402
+ # residual (+attention) blocks
403
+ if i == 1 or i == 2 or i == 3:
404
+ in_dim = in_dim // 2
405
+ for _ in range(num_res_blocks + 1):
406
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
407
+ if scale in attn_scales:
408
+ upsamples.append(AttentionBlock(out_dim))
409
+ in_dim = out_dim
410
+
411
+ # upsample block
412
+ if i != len(dim_mult) - 1:
413
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
414
+ upsamples.append(Resample(out_dim, mode=mode))
415
+ scale *= 2.0
416
+ self.upsamples = nn.Sequential(*upsamples)
417
+
418
+ # output blocks
419
+ self.head = nn.Sequential(
420
+ RMS_norm(out_dim, images=False), nn.SiLU(),
421
+ CausalConv3d(out_dim, 3, 3, padding=1))
422
+
423
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
424
+ ## conv1
425
+ if feat_cache is not None:
426
+ idx = feat_idx[0]
427
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
428
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
429
+ # cache last frame of last two chunk
430
+ cache_x = torch.cat([
431
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
432
+ cache_x.device), cache_x
433
+ ],
434
+ dim=2)
435
+ x = self.conv1(x, feat_cache[idx])
436
+ feat_cache[idx] = cache_x
437
+ feat_idx[0] += 1
438
+ else:
439
+ x = self.conv1(x)
440
+
441
+ ## middle
442
+ for layer in self.middle:
443
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
444
+ x = layer(x, feat_cache, feat_idx)
445
+ else:
446
+ x = layer(x)
447
+
448
+ ## upsamples
449
+ for layer in self.upsamples:
450
+ if feat_cache is not None:
451
+ x = layer(x, feat_cache, feat_idx)
452
+ else:
453
+ x = layer(x)
454
+
455
+ ## head
456
+ for layer in self.head:
457
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
458
+ idx = feat_idx[0]
459
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
460
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
461
+ # cache last frame of last two chunk
462
+ cache_x = torch.cat([
463
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
464
+ cache_x.device), cache_x
465
+ ],
466
+ dim=2)
467
+ x = layer(x, feat_cache[idx])
468
+ feat_cache[idx] = cache_x
469
+ feat_idx[0] += 1
470
+ else:
471
+ x = layer(x)
472
+ return x
473
+
474
+
475
+ def count_conv3d(model):
476
+ count = 0
477
+ for m in model.modules():
478
+ if isinstance(m, CausalConv3d):
479
+ count += 1
480
+ return count
481
+
482
+
483
+ class WanVAE_(nn.Module):
484
+
485
+ def __init__(self,
486
+ dim=128,
487
+ z_dim=4,
488
+ dim_mult=[1, 2, 4, 4],
489
+ num_res_blocks=2,
490
+ attn_scales=[],
491
+ temperal_downsample=[True, True, False],
492
+ dropout=0.0):
493
+ super().__init__()
494
+ self.dim = dim
495
+ self.z_dim = z_dim
496
+ self.dim_mult = dim_mult
497
+ self.num_res_blocks = num_res_blocks
498
+ self.attn_scales = attn_scales
499
+ self.temperal_downsample = temperal_downsample
500
+ self.temperal_upsample = temperal_downsample[::-1]
501
+
502
+ # modules
503
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
504
+ attn_scales, self.temperal_downsample, dropout)
505
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
506
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
507
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
508
+ attn_scales, self.temperal_upsample, dropout)
509
+
510
+ def forward(self, x):
511
+ mu, log_var = self.encode(x)
512
+ z = self.reparameterize(mu, log_var)
513
+ x_recon = self.decode(z)
514
+ return x_recon, mu, log_var
515
+
516
+ def encode(self, x, scale):
517
+ self.clear_cache()
518
+ ## cache
519
+ t = x.shape[2]
520
+ iter_ = 1 + (t - 1) // 4
521
+ ## Split the encoder input along time into chunks of 1, 4, 4, 4, ... frames
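+ ## e.g. an 81-frame clip is processed as 1 + 20 chunks of 4, yielding 21 latent frames (4x temporal compression after the first frame).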
522
+ for i in range(iter_):
523
+ self._enc_conv_idx = [0]
524
+ if i == 0:
525
+ out = self.encoder(
526
+ x[:, :, :1, :, :],
527
+ feat_cache=self._enc_feat_map,
528
+ feat_idx=self._enc_conv_idx)
529
+ else:
530
+ out_ = self.encoder(
531
+ x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
532
+ feat_cache=self._enc_feat_map,
533
+ feat_idx=self._enc_conv_idx)
534
+ out = torch.cat([out, out_], 2)
535
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
536
+ if isinstance(scale[0], torch.Tensor):
537
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
538
+ 1, self.z_dim, 1, 1, 1)
539
+ else:
540
+ mu = (mu - scale[0]) * scale[1]
541
+ self.clear_cache()
542
+ return mu
543
+
544
+ def decode(self, z, scale):
545
+ self.clear_cache()
546
+ # z: [b,c,t,h,w]
547
+ if isinstance(scale[0], torch.Tensor):
548
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
549
+ 1, self.z_dim, 1, 1, 1)
550
+ else:
551
+ z = z / scale[1] + scale[0]
552
+ iter_ = z.shape[2]
553
+ x = self.conv2(z)
554
+ for i in range(iter_):
555
+ self._conv_idx = [0]
556
+ if i == 0:
557
+ out = self.decoder(
558
+ x[:, :, i:i + 1, :, :],
559
+ feat_cache=self._feat_map,
560
+ feat_idx=self._conv_idx)
561
+ else:
562
+ out_ = self.decoder(
563
+ x[:, :, i:i + 1, :, :],
564
+ feat_cache=self._feat_map,
565
+ feat_idx=self._conv_idx)
566
+ out = torch.cat([out, out_], 2)
567
+ self.clear_cache()
568
+ return out
569
+
570
+ def reparameterize(self, mu, log_var):
571
+ std = torch.exp(0.5 * log_var)
572
+ eps = torch.randn_like(std)
573
+ return eps * std + mu
574
+
575
+ def sample(self, imgs, deterministic=False):
576
+ mu, log_var = self.encode(imgs)
577
+ if deterministic:
578
+ return mu
579
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
580
+ return mu + std * torch.randn_like(std)
581
+
582
+ def clear_cache(self):
583
+ self._conv_num = count_conv3d(self.decoder)
584
+ self._conv_idx = [0]
585
+ self._feat_map = [None] * self._conv_num
586
+ #cache encode
587
+ self._enc_conv_num = count_conv3d(self.encoder)
588
+ self._enc_conv_idx = [0]
589
+ self._enc_feat_map = [None] * self._enc_conv_num
590
+
591
+
592
+ def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
593
+ """
594
+ Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
595
+ """
596
+ # params
597
+ cfg = dict(
598
+ dim=96,
599
+ z_dim=z_dim,
600
+ dim_mult=[1, 2, 4, 4],
601
+ num_res_blocks=2,
602
+ attn_scales=[],
603
+ temperal_downsample=[False, True, True],
604
+ dropout=0.0)
605
+ cfg.update(**kwargs)
606
+
607
+ # init model
608
+ # with torch.device('meta'):
609
+ model = WanVAE_(**cfg)
610
+
611
+ # load checkpoint
612
+ logging.info(f'loading {pretrained_path}')
613
+ if pretrained_path is not None:
614
+ model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True)
615
+
616
+ return model
617
+
618
+
619
+ class WanVAE:
620
+
621
+ def __init__(self,
622
+ z_dim=16,
623
+ vae_pth=None,
624
+ dtype=torch.float,
625
+ device="cuda"):
626
+ self.dtype = dtype
627
+ self.device = device
628
+
629
+ mean = [
630
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
631
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
632
+ ]
633
+ std = [
634
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
635
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
636
+ ]
637
+ self.mean = torch.tensor(mean, dtype=dtype, device=device)
638
+ self.std = torch.tensor(std, dtype=dtype, device=device)
639
+ self.scale = [self.mean, 1.0 / self.std]
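+ # scale = [shift, inv_std]: encode() normalizes latents as (mu - mean) * inv_std and decode() undoes it.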
640
+
641
+ # init model
642
+ self.model = _video_vae(
643
+ pretrained_path=vae_pth,
644
+ z_dim=z_dim,
645
+ ).eval().requires_grad_(False).to(device)
646
+
647
+ @torch.no_grad()
648
+ def encode(self, videos, device):
649
+ """
650
+ videos: A list of videos each with shape [C, T, H, W].
651
+ """
652
+
653
+ with torch.amp.autocast('cuda', dtype=self.dtype):
654
+ return [
655
+ self.model.encode(u.unsqueeze(0).to(device,self.dtype), self.scale).float().squeeze(0)
656
+ for u in videos
657
+ ]
658
+
659
+ @torch.no_grad()
660
+ def decode(self, zs):
661
+ with torch.amp.autocast('cuda', dtype=self.dtype):
662
+ return [
663
+ self.model.decode(u.unsqueeze(0),
664
+ self.scale).float().clamp_(-1, 1).squeeze(0)
665
+ for u in zs
666
+ ]
humo/models/wan_modules/xlm_roberta.py ADDED
@@ -0,0 +1,170 @@
1
+ # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ __all__ = ['XLMRoberta', 'xlm_roberta_large']
+
+
+ class SelfAttention(nn.Module):
+
+     def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
+         assert dim % num_heads == 0
+         super().__init__()
+         self.dim = dim
+         self.num_heads = num_heads
+         self.head_dim = dim // num_heads
+         self.eps = eps
+
+         # layers
+         self.q = nn.Linear(dim, dim)
+         self.k = nn.Linear(dim, dim)
+         self.v = nn.Linear(dim, dim)
+         self.o = nn.Linear(dim, dim)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, mask):
+         """
+         x: [B, L, C].
+         """
+         b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
+
+         # compute query, key, value
+         q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
+         k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
+         v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
+
+         # compute attention
+         p = self.dropout.p if self.training else 0.0
+         x = F.scaled_dot_product_attention(q, k, v, mask, p)
+         x = x.permute(0, 2, 1, 3).reshape(b, s, c)
+
+         # output
+         x = self.o(x)
+         x = self.dropout(x)
+         return x
+
+
+ class AttentionBlock(nn.Module):
+
+     def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
+         super().__init__()
+         self.dim = dim
+         self.num_heads = num_heads
+         self.post_norm = post_norm
+         self.eps = eps
+
+         # layers
+         self.attn = SelfAttention(dim, num_heads, dropout, eps)
+         self.norm1 = nn.LayerNorm(dim, eps=eps)
+         self.ffn = nn.Sequential(
+             nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
+             nn.Dropout(dropout))
+         self.norm2 = nn.LayerNorm(dim, eps=eps)
+
+     def forward(self, x, mask):
+         if self.post_norm:
+             x = self.norm1(x + self.attn(x, mask))
+             x = self.norm2(x + self.ffn(x))
+         else:
+             x = x + self.attn(self.norm1(x), mask)
+             x = x + self.ffn(self.norm2(x))
+         return x
+
+
+ class XLMRoberta(nn.Module):
+     """
+     XLMRobertaModel with no pooler and no LM head.
+     """
+
+     def __init__(self,
+                  vocab_size=250002,
+                  max_seq_len=514,
+                  type_size=1,
+                  pad_id=1,
+                  dim=1024,
+                  num_heads=16,
+                  num_layers=24,
+                  post_norm=True,
+                  dropout=0.1,
+                  eps=1e-5):
+         super().__init__()
+         self.vocab_size = vocab_size
+         self.max_seq_len = max_seq_len
+         self.type_size = type_size
+         self.pad_id = pad_id
+         self.dim = dim
+         self.num_heads = num_heads
+         self.num_layers = num_layers
+         self.post_norm = post_norm
+         self.eps = eps
+
+         # embeddings
+         self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
+         self.type_embedding = nn.Embedding(type_size, dim)
+         self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
+         self.dropout = nn.Dropout(dropout)
+
+         # blocks
+         self.blocks = nn.ModuleList([
+             AttentionBlock(dim, num_heads, post_norm, dropout, eps)
+             for _ in range(num_layers)
+         ])
+
+         # norm layer
+         self.norm = nn.LayerNorm(dim, eps=eps)
+
+     def forward(self, ids):
+         """
+         ids: [B, L] of torch.LongTensor.
+         """
+         b, s = ids.shape
+         mask = ids.ne(self.pad_id).long()
+
+         # embeddings
+         x = self.token_embedding(ids) + \
+             self.type_embedding(torch.zeros_like(ids)) + \
+             self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
+         if self.post_norm:
+             x = self.norm(x)
+         x = self.dropout(x)
+
+         # blocks
+         mask = torch.where(
+             mask.view(b, 1, 1, s).gt(0), 0.0,
+             torch.finfo(x.dtype).min)
+         for block in self.blocks:
+             x = block(x, mask)
+
+         # output
+         if not self.post_norm:
+             x = self.norm(x)
+         return x
+
+
+ def xlm_roberta_large(pretrained=False,
+                       return_tokenizer=False,
+                       device='cpu',
+                       **kwargs):
+     """
+     XLMRobertaLarge adapted from Huggingface.
+     """
+     # params
+     cfg = dict(
+         vocab_size=250002,
+         max_seq_len=514,
+         type_size=1,
+         pad_id=1,
+         dim=1024,
+         num_heads=16,
+         num_layers=24,
+         post_norm=True,
+         dropout=0.1,
+         eps=1e-5)
+     cfg.update(**kwargs)
+
+     # init a model on device
+     with torch.device(device):
+         model = XLMRoberta(**cfg)
+     return model
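A minimal sketch of running the encoder above on dummy token ids (illustrative only; the import path assumes "humo" is on sys.path, and the weights here are randomly initialized, so the features are only meaningful after loading a matching checkpoint):

import torch
from models.wan_modules.xlm_roberta import xlm_roberta_large  # assumes "humo" is on sys.path

model = xlm_roberta_large(device='cpu').eval()
ids = torch.full((1, 16), 5, dtype=torch.long)  # [B, L] token ids
ids[0, 12:] = 1                                 # right-pad with the pad id (1)
with torch.no_grad():
    feats = model(ids)                          # [1, 16, 1024] per-token features
print(feats.shape)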
humo/utils/audio_processor_whisper.py ADDED
@@ -0,0 +1,173 @@
+ # pylint: disable=C0301
+ '''
+ This module contains the AudioProcessor class and related functions for processing audio data.
+ It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction,
+ and audio separation. The class is initialized with configuration parameters and can process
+ audio files using the provided models.
+ '''
+ import os
+ import subprocess
+
+ import librosa
+ import numpy as np
+ import torch
+ from audio_separator.separator import Separator
+ from transformers import WhisperModel, AutoFeatureExtractor
+ import torch.nn.functional as F
+
+
+ def linear_interpolation_fps(features, input_fps, output_fps, output_len=None):
+     features = features.transpose(1, 2)  # [1, C, T]
+     seq_len = features.shape[2] / float(input_fps)
+     if output_len is None:
+         output_len = int(seq_len * output_fps)
+     output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear')
+     return output_features.transpose(1, 2)
+
+
+ def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int):
+     p = subprocess.Popen([
+         "ffmpeg", "-y", "-v", "error", "-i", input_audio_file, "-ar", str(sample_rate), output_audio_file
+     ])
+     ret = p.wait()
+     assert ret == 0, "Resample audio failed!"
+     return output_audio_file
+
+
+ class AudioProcessor:
+     """
+     AudioProcessor is a class that handles the processing of audio files.
+     It takes care of preprocessing the audio files, extracting features
+     using the Whisper encoder, and separating vocals from the audio if needed.
+
+     :param sample_rate: Sampling rate of the audio file
+     :param fps: Frames per second for the extracted features
+     :param wav2vec_model_path: Path to the audio encoder (Whisper) model
+     :param wav2vec_feature_type: Type of features to extract from the audio encoder
+     :param audio_separator_model_path: Path to the audio separator model
+     :param audio_separator_model_name: Name of the audio separator model
+     :param cache_dir: Directory to cache the intermediate results
+     :param device: Device to run the processing on
+     """
+     def __init__(
+         self,
+         sample_rate,
+         fps,
+         wav2vec_model_path,
+         wav2vec_feature_type,
+         audio_separator_model_path: str = None,
+         audio_separator_model_name: str = None,
+         cache_dir: str = '',
+         device="cuda:0",
+     ) -> None:
+         self.sample_rate = sample_rate
+         self.fps = fps
+         self.device = device
+
+         self.whisper = WhisperModel.from_pretrained(wav2vec_model_path).to(device).eval()
+         self.whisper.requires_grad_(False)
+         self.feature_extractor = AutoFeatureExtractor.from_pretrained(wav2vec_model_path)
+
+         if audio_separator_model_name is not None:
+             try:
+                 os.makedirs(cache_dir, exist_ok=True)
+             except OSError as _:
+                 print("Failed to create the output cache dir.")
+             self.audio_separator = Separator(
+                 output_dir=cache_dir,
+                 output_single_stem="vocals",
+                 model_file_dir=audio_separator_model_path,
+             )
+             self.audio_separator.load_model(audio_separator_model_name)
+             assert self.audio_separator.model_instance is not None, "Failed to load the audio separation model."
+         else:
+             self.audio_separator = None
+             print("Using audio directly without vocal separation.")
+
+     def get_audio_feature(self, audio_path):
+         audio_input, sampling_rate = librosa.load(audio_path, sr=16000)
+         assert sampling_rate == 16000
+
+         audio_features = []
+         window = 750 * 640
+         for i in range(0, len(audio_input), window):
+             audio_feature = self.feature_extractor(
+                 audio_input[i:i + window],
+                 sampling_rate=sampling_rate,
+                 return_tensors="pt",
+             ).input_features
+             audio_features.append(audio_feature)
+         audio_features = torch.cat(audio_features, dim=-1)
+         return audio_features, len(audio_input) // 640
+
+     def preprocess(self, audio_path: str):
+         audio_input, audio_len = self.get_audio_feature(audio_path)
+         audio_feature = audio_input.to(self.whisper.device).float()
+         window = 3000
+         audio_prompts = []
+         for i in range(0, audio_feature.shape[-1], window):
+             audio_prompt = self.whisper.encoder(audio_feature[:, :, i:i + window], output_hidden_states=True).hidden_states
+             audio_prompt = torch.stack(audio_prompt, dim=2)
+             audio_prompts.append(audio_prompt)
+
+         audio_prompts = torch.cat(audio_prompts, dim=1)
+         audio_prompts = audio_prompts[:, :audio_len * 2]
+
+         audio_emb = self.audio_emb_enc(audio_prompts, wav_enc_type="whisper")
+
+         return audio_emb, audio_emb.shape[0]
+
+     def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"):
+         if wav_enc_type == "wav2vec":
+             feat_merge = audio_emb
+         elif wav_enc_type == "whisper":
+             # [1, T, 33, 1280]
+             feat0 = linear_interpolation_fps(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25)
+             feat1 = linear_interpolation_fps(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25)
+             feat2 = linear_interpolation_fps(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25)
+             feat3 = linear_interpolation_fps(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25)
+             feat4 = linear_interpolation_fps(audio_emb[:, :, 32], 50, 25)
+             feat_merge = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0]  # [T, 5, 1280]
+         else:
+             raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}")
+
+         return feat_merge
+
+     def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2):
+         zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
+         zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
+         iter_ = 1 + (frame_num - 1) // 4
+         audio_emb_wind = []
+         for lt_i in range(iter_):  # latent_i
+             if lt_i == 0:
+                 # First VAE latent frame: pad the audio window with zeros on the left to mark it.
+                 st = frame0_idx + lt_i - 2
+                 ed = frame0_idx + lt_i + 3
+                 wind_feat = torch.stack([
+                     audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
+                     for i in range(st, ed)
+                 ], dim=0)  # [5, 13, 768]
+                 wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0)  # [8, 13, 768]
+             else:
+                 st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
+                 ed = frame0_idx + 1 + 4 * lt_i + audio_shift
+                 wind_feat = torch.stack([
+                     audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
+                     for i in range(st, ed)
+                 ], dim=0)  # [8, 13, 768]
+             audio_emb_wind.append(wind_feat)
+         audio_emb_wind = torch.stack(audio_emb_wind, dim=0)  # [iter_, 8, 13, 768]
+
+         return audio_emb_wind, ed - audio_shift
+
+     def close(self):
+         """
+         TODO: to be implemented
+         """
+         return self
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, _exc_type, _exc_val, _exc_tb):
+         self.close()
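A minimal sketch of how this class might be driven (illustrative only; the Whisper checkpoint name and the GPU device are assumptions, and examples/dream.mp3 is one of the audio files uploaded in this commit):

from utils.audio_processor_whisper import AudioProcessor  # assumes "humo" is on sys.path

processor = AudioProcessor(
    sample_rate=16000,
    fps=25,
    wav2vec_model_path="openai/whisper-large-v3",  # placeholder: any Whisper checkpoint with 1280-dim hidden states
    wav2vec_feature_type="whisper",
    device="cuda:0",
)
audio_emb, num_frames = processor.preprocess("examples/dream.mp3")  # [T, 5, 1280] features at 25 fps
windows, next_idx = processor.get_audio_emb_window(audio_emb, frame_num=97, frame0_idx=0)
print(audio_emb.shape, windows.shape)  # windows: [1 + (97 - 1) // 4, 8, 5, 1280]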
humo/utils/wav2vec.py ADDED
@@ -0,0 +1,218 @@
+ # pylint: disable=R0901
+ # src/models/wav2vec.py
+
+ """
+ This module defines the Wav2Vec model, which is a pre-trained model for speech recognition and understanding.
+ It inherits from the Wav2Vec2Model class in the transformers library and provides additional functionalities
+ such as feature extraction and encoding.
+
+ Classes:
+     Wav2VecModel: Inherits from Wav2Vec2Model and adds additional methods for feature extraction and encoding.
+
+ Functions:
+     linear_interpolation: Interpolates the features based on the sequence length.
+ """
+
+ import torch.nn.functional as F
+ from transformers import Wav2Vec2Model
+ from transformers.modeling_outputs import BaseModelOutput
+
+
+ class Wav2VecModel(Wav2Vec2Model):
+     """
+     Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library.
+     It inherits all the functionality of the Wav2Vec2Model and adds additional methods for feature extraction and encoding.
+     ...
+
+     Attributes:
+         base_model (Wav2Vec2Model): The base Wav2Vec2Model object.
+
+     Methods:
+         forward(input_values, seq_len, attention_mask=None, mask_time_indices=None,
+                 output_attentions=None, output_hidden_states=None, return_dict=None):
+             Forward pass of the Wav2VecModel.
+             It takes input_values, seq_len, and other optional parameters as input and returns the output of the base model.
+
+         feature_extract(input_values, seq_len):
+             Extracts features from the input_values using the base model.
+
+         encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None):
+             Encodes the extracted features using the base model and returns the encoded features.
+     """
+     def forward(
+         self,
+         input_values,
+         seq_len,
+         attention_mask=None,
+         mask_time_indices=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         """
+         Forward pass of the Wav2Vec model.
+
+         Args:
+             self: The instance of the model.
+             input_values: The input values (waveform) to the model.
+             seq_len: The sequence length of the input values.
+             attention_mask: Attention mask to be used for the model.
+             mask_time_indices: Mask indices to be used for the model.
+             output_attentions: If set to True, returns attentions.
+             output_hidden_states: If set to True, returns hidden states.
+             return_dict: If set to True, returns a BaseModelOutput instead of a tuple.
+
+         Returns:
+             The output of the Wav2Vec model.
+         """
+         self.config.output_attentions = True
+
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         extract_features = self.feature_extractor(input_values)
+         extract_features = extract_features.transpose(1, 2)
+         extract_features = linear_interpolation(extract_features, seq_len=seq_len)
+
+         if attention_mask is not None:
+             # compute reduced attention_mask corresponding to feature vectors
+             attention_mask = self._get_feature_vector_attention_mask(
+                 extract_features.shape[1], attention_mask, add_adapter=False
+             )
+
+         hidden_states, extract_features = self.feature_projection(extract_features)
+         hidden_states = self._mask_hidden_states(
+             hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+         )
+
+         encoder_outputs = self.encoder(
+             hidden_states,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = encoder_outputs[0]
+
+         if self.adapter is not None:
+             hidden_states = self.adapter(hidden_states)
+
+         if not return_dict:
+             return (hidden_states, ) + encoder_outputs[1:]
+         return BaseModelOutput(
+             last_hidden_state=hidden_states,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+
+     def feature_extract(
+         self,
+         input_values,
+         seq_len,
+     ):
+         """
+         Extracts features from the input values and returns the extracted features.
+
+         Parameters:
+             input_values (torch.Tensor): The input values to be processed.
+             seq_len (torch.Tensor): The sequence lengths of the input values.
+
+         Returns:
+             extracted_features (torch.Tensor): The extracted features from the input values.
+         """
+         extract_features = self.feature_extractor(input_values)
+         extract_features = extract_features.transpose(1, 2)
+         extract_features = linear_interpolation(extract_features, seq_len=seq_len)
+
+         return extract_features
+
+     def encode(
+         self,
+         extract_features,
+         attention_mask=None,
+         mask_time_indices=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         """
+         Encodes the input features into the output space.
+
+         Args:
+             extract_features (torch.Tensor): The extracted features from the audio signal.
+             attention_mask (torch.Tensor, optional): Attention mask to be used for padding.
+             mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension.
+             output_attentions (bool, optional): If set to True, returns the attention weights.
+             output_hidden_states (bool, optional): If set to True, returns all hidden states.
+             return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of the tuple.
+
+         Returns:
+             The encoded output features.
+         """
+         self.config.output_attentions = True
+
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if attention_mask is not None:
+             # compute reduced attention_mask corresponding to feature vectors
+             attention_mask = self._get_feature_vector_attention_mask(
+                 extract_features.shape[1], attention_mask, add_adapter=False
+             )
+
+         hidden_states, extract_features = self.feature_projection(extract_features)
+         hidden_states = self._mask_hidden_states(
+             hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+         )
+
+         encoder_outputs = self.encoder(
+             hidden_states,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = encoder_outputs[0]
+
+         if self.adapter is not None:
+             hidden_states = self.adapter(hidden_states)
+
+         if not return_dict:
+             return (hidden_states, ) + encoder_outputs[1:]
+         return BaseModelOutput(
+             last_hidden_state=hidden_states,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+
+
+ def linear_interpolation(features, seq_len):
+     """
+     Transpose the features to interpolate linearly.
+
+     Args:
+         features (torch.Tensor): The extracted features to be interpolated.
+         seq_len (torch.Tensor): The sequence lengths of the features.
+
+     Returns:
+         torch.Tensor: The interpolated features.
+     """
+     features = features.transpose(1, 2)
+     output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
+     return output_features.transpose(1, 2)
+
+
+ def linear_interpolation_fps(features, input_fps, output_fps, output_len=None):
+     features = features.transpose(1, 2)  # [1, C, T]
+     seq_len = features.shape[2] / float(input_fps)
+     if output_len is None:
+         output_len = int(seq_len * output_fps)
+     output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear')
+     return output_features.transpose(1, 2)
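A minimal sketch of using the subclass above to resample wav2vec2 features to a video frame rate (illustrative only; the checkpoint name is a placeholder and the import path assumes "humo" is on sys.path):

import torch
from utils.wav2vec import Wav2VecModel  # assumes "humo" is on sys.path

model = Wav2VecModel.from_pretrained("facebook/wav2vec2-base-960h").eval()  # placeholder checkpoint
waveform = torch.randn(1, 16000 * 2)  # 2 s of 16 kHz audio
seq_len = 50                          # e.g. 2 s of video at 25 fps
with torch.no_grad():
    out = model(waveform, seq_len=seq_len)
print(out.last_hidden_state.shape)    # [1, 50, 768]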
main.py ADDED
@@ -0,0 +1,28 @@
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Inference codes adapted from [SeedVR]
+ # https://github.com/ByteDance-Seed/SeedVR/blob/main/projects/inference_seedvr2_7b.py
+
+ from sys import argv
+ import sys
+
+ path_to_insert = "humo"
+ if path_to_insert not in sys.path:
+     sys.path.insert(0, path_to_insert)
+
+ from common.config import load_config, create_object
+
+ # Load config.
+ config = load_config(argv[1], argv[2:])
+
+ runner = create_object(config)
+ runner.entrypoint()
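As a usage note (the exact options depend on the configs shipped in this commit): the entrypoint is invoked as `python main.py humo/configs/inference/generate.yaml`, with any additional command-line arguments forwarded to `load_config` as overrides.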