diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..c729f041a0c89076070ba075b8e8df90e5ad187b 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/teaser.png filter=lfs diff=lfs merge=lfs -text
+examples/amber.png filter=lfs diff=lfs merge=lfs -text
+examples/armour.png filter=lfs diff=lfs merge=lfs -text
+examples/art.wav filter=lfs diff=lfs merge=lfs -text
+examples/chris.png filter=lfs diff=lfs merge=lfs -text
+examples/dream.mp3 filter=lfs diff=lfs merge=lfs -text
+examples/fictional.wav filter=lfs diff=lfs merge=lfs -text
+examples/fight.wav filter=lfs diff=lfs merge=lfs -text
+examples/jacket.png filter=lfs diff=lfs merge=lfs -text
+examples/naomi.png filter=lfs diff=lfs merge=lfs -text
+examples/vangogh.jpg filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..62ae395d2b41b3006f5f79d5a7539bad424b21a6
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2025 Bytedance
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index c9bc0df86bbd5f3622eb509cf68dca0e9d3289cf..1c090487f44deadc989c3613db1527af22827d6a 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,13 @@
----
-title: HuMo Local
-emoji: 🌍
-colorFrom: green
-colorTo: red
-sdk: gradio
-sdk_version: 5.49.1
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+---
+title: HuMo [Local]
+emoji: 👩🦱
+colorFrom: purple
+colorTo: gray
+sdk: gradio
+sdk_version: 5.47.2
+app_file: app.py
+pinned: false
+short_description: Reference-based video generation
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f8dfff7e5164c3e2b60234481c9e6a2527b47bc
--- /dev/null
+++ b/app.py
@@ -0,0 +1,435 @@
+import spaces
+import gradio as gr
+import sys
+import os
+import subprocess
+import uuid
+import shutil
+
+
+
+from huggingface_hub import snapshot_download, list_repo_files, hf_hub_download
+import importlib, site
+
+
+# Re-discover all .pth/.egg-link files
+for sitedir in site.getsitepackages():
+ site.addsitedir(sitedir)
+
+# Clear caches so importlib will pick up new modules
+importlib.invalidate_caches()
+
+def sh(cmd): subprocess.check_call(cmd, shell=True)
+
+flash_attention_installed = False
+
+try:
+ flash_attention_wheel = hf_hub_download(
+ repo_id="alexnasa/flash-attn-3",
+ repo_type="model",
+ filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
+ )
+
+    print("Attempting to download and install FlashAttention wheel...")
+    sh(f"pip install {flash_attention_wheel}")
+    # sh("pip install flash-attn")
+
+ # tell Python to re-scan site-packages now that the egg-link exists
+ import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
+
+ flash_attention_installed = True
+
+except Exception as e:
+ print(f"⚠️ Could not install FlashAttention: {e}")
+ print("Continuing without FlashAttention...")
+
+try:
+ te_wheel = hf_hub_download(
+ repo_id="alexnasa/transformer_engine_wheels",
+ repo_type="model",
+ filename="transformer_engine-2.5.0+f05f12c9-cp310-cp310-linux_x86_64.whl",
+ )
+
+    print("Attempting to download and install Transformer Engine wheel...")
+    sh(f"pip install {te_wheel}")
+
+ # tell Python to re-scan site-packages now that the egg-link exists
+ import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
+
+except Exception as e:
+    print(f"⚠️ Could not install Transformer Engine: {e}")
+    print("Continuing without Transformer Engine...")
+
+import torch
+print(f"Torch version: {torch.__version__}")
+print(f"FlashAttention available: {flash_attention_installed}")
+
+import tempfile
+from pathlib import Path
+from torch._inductor.runtime.runtime_utils import cache_dir as _inductor_cache_dir
+from huggingface_hub import HfApi
+
+
+snapshot_download(repo_id="bytedance-research/HuMo", local_dir="./weights/HuMo")
+snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./weights/Wan2.1-T2V-1.3B")
+snapshot_download(repo_id="openai/whisper-large-v3", local_dir="./weights/whisper-large-v3")
+
+os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/proprocess_results"
+
+path_to_insert = "humo"
+if path_to_insert not in sys.path:
+ sys.path.insert(0, path_to_insert)
+
+from common.config import load_config, create_object
+
+config = load_config(
+ "./humo/configs/inference/generate.yaml",
+ [
+ "dit.sp_size=1",
+ "generation.frames=97",
+ "generation.scale_t=5.5",
+ "generation.scale_a=5.0",
+ "generation.mode=TIA",
+ "generation.height=480",
+ "generation.width=832",
+ ],
+)
+runner = create_object(config)
+
+
+os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", f"{os.getcwd()}/torchinductor_space") # or another writable path
+
+def restore_inductor_cache_from_hub(repo_id: str, filename: str = "torch_compile_cache.zip",
+ path_in_repo: str = "inductor_cache", repo_type: str = "model",
+ hf_token: str | None = None):
+ cache_root = Path(_inductor_cache_dir()).resolve()
+ cache_root.mkdir(parents=True, exist_ok=True)
+ zip_path = hf_hub_download(repo_id=repo_id, filename=f"{path_in_repo}/{filename}",
+ repo_type=repo_type, token=hf_token)
+ shutil.unpack_archive(zip_path, extract_dir=str(cache_root))
+ print(f"✓ Restored cache into {cache_root}")
+
+
+# restore_inductor_cache_from_hub("alexnasa/humo-compiled")
+
+
+def get_duration(prompt_text, steps, image_paths, audio_file_path, tea_cache_l1_thresh=0.0, max_duration=2, session_id=None):
+
+ return calculate_required_time(steps, max_duration)
+
+def calculate_required_time(steps, max_duration):
+
+ warmup_s = 60
+
+ max_duration_duration_mapping = {
+ 1: 8,
+ 2: 8,
+ 3: 11,
+ 4: 20,
+ 5: 30,
+ }
+ each_step_s = max_duration_duration_mapping[max_duration]
+ duration_s = (each_step_s * steps) + warmup_s
+
+ print(f'estimated duration:{duration_s}')
+
+ return int(duration_s)
+
+def get_required_time_string(steps, max_duration):
+
+ duration_s = calculate_required_time(steps, max_duration)
+ duration_m = duration_s / 60
+
+    return f"⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)"
+
+def update_required_time(steps, max_duration):
+
+ return get_required_time_string(steps, max_duration)
+
+
+def generate_scene(prompt_text, steps, image_paths, audio_file_path, tea_cache_l1_thresh, max_duration = 2, session_id = None):
+
+ print(image_paths)
+ prompt_text_check = (prompt_text or "").strip()
+ if not prompt_text_check:
+ raise gr.Error("Please enter a prompt.")
+
+ if not audio_file_path and not image_paths:
+ raise gr.Error("Please provide a reference image or a lipsync audio.")
+
+ return run_pipeline(prompt_text, steps, image_paths, audio_file_path, tea_cache_l1_thresh, max_duration, session_id)
+
+
+
+def upload_inductor_cache_to_hub(
+ repo_id: str,
+ path_in_repo: str = "inductor_cache",
+ repo_type: str = "model", # or "dataset" if you prefer
+ hf_token: str | None = None,
+):
+ """
+ Zips the current TorchInductor cache and uploads it to the given repo path.
+ Assumes the model was already run once with torch.compile() so the cache exists.
+ """
+
+ cache_dir = Path(_inductor_cache_dir()).resolve()
+ if not cache_dir.exists():
+ raise FileNotFoundError(f"TorchInductor cache not found at {cache_dir}. "
+ "Run a compiled model once to populate it.")
+
+ # Create a zip archive of the entire cache directory
+ with tempfile.TemporaryDirectory() as tmpdir:
+ archive_base = Path(tmpdir) / "torch_compile_cache"
+ archive_path = shutil.make_archive(str(archive_base), "zip", root_dir=str(cache_dir))
+ archive_path = Path(archive_path)
+
+ # Upload to Hub
+ api = HfApi(token=hf_token)
+ api.create_repo(repo_id=repo_id, repo_type=repo_type, exist_ok=True)
+ # Put each artifact under path_in_repo, including a tiny metadata stamp for traceability
+ # Upload the zip
+ dest_path = f"{path_in_repo}/{archive_path.name}"
+ api.upload_file(
+ path_or_fileobj=str(archive_path),
+ path_in_repo=dest_path,
+ repo_id=repo_id,
+ repo_type=repo_type,
+ )
+ # Upload a small metadata file (optional but handy)
+ meta_txt = (
+ f"pytorch={torch.__version__}\n"
+ f"inductor_cache_dir={cache_dir}\n"
+ f"cuda_available={torch.cuda.is_available()}\n"
+ f"cuda_device={torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu'}\n"
+ )
+ api.upload_file(
+ path_or_fileobj=meta_txt.encode(),
+ path_in_repo=f"{path_in_repo}/INDUCTOR_CACHE_METADATA.txt",
+ repo_id=repo_id,
+ repo_type=repo_type,
+ )
+
+ print("✔ Uploaded TorchInductor cache to the Hub.")
+
+
+@spaces.GPU(duration=get_duration)
+def run_pipeline(prompt_text, steps, image_paths, audio_file_path, tea_cache_l1_thresh = 0.0, max_duration = 2, session_id = None):
+
+ if session_id is None:
+ session_id = uuid.uuid4().hex
+
+ inference_mode = "TIA"
+
+ # Validate inputs
+ prompt_text = (prompt_text or "").strip()
+ if not prompt_text:
+ raise gr.Error("Please enter a prompt.")
+
+ if not audio_file_path and not image_paths:
+ raise gr.Error("Please provide a reference image or a lipsync audio.")
+
+ if not audio_file_path:
+ inference_mode = "TI"
+ audio_path = None
+ else:
+ audio_path = audio_file_path if isinstance(audio_file_path, str) else getattr(audio_file_path, "name", str(audio_file_path))
+
+ if not image_paths:
+ inference_mode = "TA"
+ img_paths = None
+ else:
+ img_paths = [image_data[0] for image_data in image_paths]
+
+
+ # Prepare output
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Random filename
+ filename = f"gen_{uuid.uuid4().hex[:10]}"
+ width, height = 832, 480
+
+ duration_frame_mapping = {
+ 1:25,
+ 2:45,
+ 3:70,
+ 4:97,
+ 5:129
+ }
+
+ # Run inference
+ runner.inference_loop(
+ prompt_text,
+ img_paths,
+ audio_path,
+ output_dir,
+ filename,
+ inference_mode,
+ width,
+ height,
+ steps,
+ frames = int(duration_frame_mapping[max_duration]),
+ tea_cache_l1_thresh = tea_cache_l1_thresh,
+ )
+
+ # Return resulting video path
+ video_path = os.path.join(output_dir, f"{filename}.mp4")
+ if os.path.exists(video_path):
+
+ # upload_inductor_cache_to_hub("alexnasa/humo-compiled")
+
+ return video_path
+ else:
+ candidates = [os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.endswith(".mp4")]
+ if candidates:
+ return max(candidates, key=lambda p: os.path.getmtime(p))
+ return None
+
+css = """
+ #col-container {
+ margin: 0 auto;
+ width: 100%;
+ max-width: 720px;
+ }
+ """
+
+def cleanup(request: gr.Request):
+
+ sid = request.session_hash
+ if sid:
+ d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
+ shutil.rmtree(d1, ignore_errors=True)
+
+def start_session(request: gr.Request):
+
+ return request.session_hash
+
+with gr.Blocks(css=css) as demo:
+
+ session_state = gr.State()
+ demo.load(start_session, outputs=[session_state])
+
+ with gr.Sidebar(width=400):
+
+
+ gr.HTML(
+ """
+            HuMo – Human-Centric Video Generation via Collaborative Multi-Modal Conditioning
+
+            [Github]
+ """
+ )
+
+ gr.Markdown("**REFERENCE IMAGES**")
+
+ img_input = gr.Gallery(
+ show_label=False,
+ label="",
+ interactive=True,
+ rows=1, columns=3, object_fit="contain", height="280",
+ file_types=['image']
+ )
+
+ gr.Markdown("**LIPSYNC AUDIO**")
+
+ audio_input = gr.Audio(
+ sources=["upload"],
+ show_label=False,
+ type="filepath",
+ )
+
+ gr.Markdown("**SETTINGS**")
+
+ default_steps = 10
+ default_max_duration = 2
+
+ max_duration = gr.Slider(minimum=2, maximum=5, value=default_max_duration, step=1, label="Max Duration")
+ steps_input = gr.Slider(minimum=5, maximum=50, value=default_steps, step=5, label="Diffusion Steps")
+ tea_cache_l1_thresh = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.01, label="Cache", visible=False)
+
+
+
+ with gr.Column(elem_id="col-container"):
+
+ gr.HTML(
+ """
+            HF Space by:
+ """
+ )
+
+ video_output = gr.Video(show_label=False)
+
+        gr.Markdown("PROMPT")
+
+ prompt_tb = gr.Textbox(
+ show_label=False,
+ lines=5,
+ placeholder="Describe the scene and the person talking....",
+ )
+
+ gr.Markdown("")
+ time_required = gr.Markdown(get_required_time_string(default_steps, default_max_duration))
+ run_btn = gr.Button("🎬 Action", variant="primary")
+
+ gr.Examples(
+ examples=[
+
+ [
+ "A handheld tracking shot follows a female warrior walking through a cave. Her determined eyes are locked straight ahead. She speaks with intensity.",
+ 5,
+ ["./examples/naomi.png"],
+ "./examples/dream.mp3",
+ ],
+
+ [
+ "A reddish-brown haired and bearded man sits pensively against swirling blue-and-white brushstrokes, dressed in a blue coat and dark waistcoat. The artistic backdrop and his thoughtful pose evoke a Post-Impressionist style in a studio-like setting.",
+ 10,
+ ["./examples/vangogh.jpg"],
+ "./examples/art.wav",
+ ],
+
+ [
+ "A handheld tracking shot follows a female through a science lab. Her determined eyes are locked straight ahead. The clip is in black and white and patchy as she is explaining something to someone standing opposite her",
+ 10,
+ ["./examples/naomi.png"],
+ "./examples/science.wav",
+ ],
+
+ [
+                "A woman with long, wavy dark hair looks at a person sitting opposite her while holding a book, wearing the long-sleeved leather jacket in the semi-purple color seen in the reference photo. Warm, window-like light bathes her figure, highlighting the outfit's elegant design and her graceful movements.",
+ 50,
+ ["./examples/amber.png", "./examples/jacket.png"],
+                "./examples/fictional.wav",
+ ],
+
+ ],
+ inputs=[prompt_tb, steps_input, img_input, audio_input],
+ outputs=[video_output],
+ fn=run_pipeline,
+ cache_examples=True,
+ )
+ max_duration.change(update_required_time, [steps_input, max_duration], time_required)
+ steps_input.change(update_required_time, [steps_input, max_duration], time_required)
+
+ run_btn.click(
+ fn=generate_scene,
+ inputs=[prompt_tb, steps_input, img_input, audio_input, tea_cache_l1_thresh, max_duration, session_state],
+ outputs=[video_output],
+ )
+
+
+if __name__ == "__main__":
+ demo.unload(cleanup)
+ demo.queue()
+ demo.launch(ssr_mode=False)
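The ZeroGPU budget requested for `run_pipeline` comes from `calculate_required_time`: a flat 60 s warm-up plus a per-step cost taken from the Max Duration slider. A small self-contained check of that arithmetic, with the constants copied from `app.py` above (the helper name `estimate_seconds` is only for illustration):

```python
# Worked check of the ZeroGPU time estimate used by @spaces.GPU(duration=get_duration).
per_step_s = {1: 8, 2: 8, 3: 11, 4: 20, 5: 30}   # seconds per diffusion step, keyed by Max Duration

def estimate_seconds(steps: int, max_duration: int) -> int:
    return per_step_s[max_duration] * steps + 60  # 60 s warm-up

assert estimate_seconds(10, 2) == 140    # UI defaults: 10 steps, Max Duration 2
assert estimate_seconds(50, 5) == 1560   # worst case reachable from the sliders
```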
diff --git a/assets/teaser.png b/assets/teaser.png
new file mode 100644
index 0000000000000000000000000000000000000000..906c01b24258d9cbea17374f129eec02f68ada6c
--- /dev/null
+++ b/assets/teaser.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:722d29d27fb89a6e1ebebef233492f0c06c25b09a0bdb8e723ef567e778bcf34
+size 5832395
diff --git a/examples/amber.png b/examples/amber.png
new file mode 100644
index 0000000000000000000000000000000000000000..be254a66a54f637325bd46a5b43c01f32770c58e
--- /dev/null
+++ b/examples/amber.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6ce1a891ea71b184eeb4bc768322006c4ecccc8e063b2a0afd38829c6e975f03
+size 2396702
diff --git a/examples/armour.png b/examples/armour.png
new file mode 100644
index 0000000000000000000000000000000000000000..0fa22b591e19170c6fcb976340798541ff31c983
--- /dev/null
+++ b/examples/armour.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:192dd4b1c80c9ddacb8678962b5a1c04855d44c9877810aba032761ce50052a2
+size 1790470
diff --git a/examples/art.wav b/examples/art.wav
new file mode 100644
index 0000000000000000000000000000000000000000..f5ecf0684d1849be5c43d623e9249ee168a2f5f0
--- /dev/null
+++ b/examples/art.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:72c75df8e93a107e262ea9b002a66e72d3c1cd2084bce1474a31d8afffd0b651
+size 114254
diff --git a/examples/chris.png b/examples/chris.png
new file mode 100644
index 0000000000000000000000000000000000000000..47e1466a43fa607d689a0ca94c5f3c1956511b14
--- /dev/null
+++ b/examples/chris.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a3100088e3247d8ecaf1ded2e2417e70d6ad34d24ce7f6e7551cb7fd24c91dcf
+size 2053453
diff --git a/examples/dream.mp3 b/examples/dream.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..6dfc3d08efa1132d14863bc64781efc818e5c532
--- /dev/null
+++ b/examples/dream.mp3
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:27248fd9e8f29bd60ccb1163b8df3c6f2630734f358aa3362ffe67e8148e0eb1
+size 108275
diff --git a/examples/fictional.wav b/examples/fictional.wav
new file mode 100644
index 0000000000000000000000000000000000000000..27bb74c78eb069d23b132b6fa9de7461ad2db58b
--- /dev/null
+++ b/examples/fictional.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:31b550e6433ea44a0642dee90c326664ff4f568fec184170001f834597b3ad23
+size 167084
diff --git a/examples/fight.wav b/examples/fight.wav
new file mode 100644
index 0000000000000000000000000000000000000000..5763aaaf7af9d8056a3ade8b25ca5e5e52c031be
--- /dev/null
+++ b/examples/fight.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8dbee86c85e992ac6d17820a3730bf753fc9bf5bac6b8a470f84b7e98a64221a
+size 264782
diff --git a/examples/jacket.png b/examples/jacket.png
new file mode 100644
index 0000000000000000000000000000000000000000..6530268bff18c6d841bb492f906ef500a716cbf4
--- /dev/null
+++ b/examples/jacket.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e80a02659148e3364eaa46e46dd64a268f04c7f7eeed0e8f203b6b848738666a
+size 2565494
diff --git a/examples/naomi.png b/examples/naomi.png
new file mode 100644
index 0000000000000000000000000000000000000000..c9d117f403585c16157f3f06852f538458ff0a10
--- /dev/null
+++ b/examples/naomi.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5666cd6253658e76695e8529b28d730c74f9e63a1afca07e47505e59b24e7656
+size 1177240
diff --git a/examples/science.wav b/examples/science.wav
new file mode 100644
index 0000000000000000000000000000000000000000..6486ce5d79678c56b3d6dea9174e02b64a0dac8e
Binary files /dev/null and b/examples/science.wav differ
diff --git a/examples/vangogh.jpg b/examples/vangogh.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e5441c8aa1deac4b12023da5001f6034682c9e8c
--- /dev/null
+++ b/examples/vangogh.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1ae77da89271f32196ad9e8a915e20a7f71e9a84b78ecec3aa1dfcb4e4b39de1
+size 138650
diff --git a/humo/common/__init__.py b/humo/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/humo/common/config.py b/humo/common/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea65b5da331eb2acfc7c26d7146ea01cabbb0c9d
--- /dev/null
+++ b/humo/common/config.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Codes adapted from [SeedVR]
+# https://github.com/ByteDance-Seed/SeedVR/blob/main/common/config.py
+
+"""
+Configuration utility functions
+"""
+
+import importlib
+from typing import Any, Callable, List, Union
+from omegaconf import DictConfig, ListConfig, OmegaConf
+
+OmegaConf.register_new_resolver("eval", eval)
+
+
+def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]:
+ """
+ Load a configuration. Will resolve inheritance.
+ """
+ config = OmegaConf.load(path)
+ if argv is not None:
+ config_argv = OmegaConf.from_dotlist(argv)
+ config = OmegaConf.merge(config, config_argv)
+ config = resolve_recursive(config, resolve_inheritance)
+ return config
+
+
+def resolve_recursive(
+ config: Any,
+ resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]],
+) -> Any:
+ config = resolver(config)
+ if isinstance(config, DictConfig):
+ for k in config.keys():
+ v = config.get(k)
+ if isinstance(v, (DictConfig, ListConfig)):
+ config[k] = resolve_recursive(v, resolver)
+ if isinstance(config, ListConfig):
+ for i in range(len(config)):
+ v = config.get(i)
+ if isinstance(v, (DictConfig, ListConfig)):
+ config[i] = resolve_recursive(v, resolver)
+ return config
+
+
+def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any:
+ """
+ Recursively resolve inheritance if the config contains:
+ __inherit__: path/to/parent.yaml.
+ """
+ if isinstance(config, DictConfig):
+ inherit = config.pop("__inherit__", None)
+ if inherit:
+ assert isinstance(inherit, str)
+ inherit = load_config(inherit)
+ if len(config.keys()) > 0:
+ config = OmegaConf.merge(inherit, config)
+ else:
+ config = inherit
+ return config
+
+
+def import_item(path: str, name: str) -> Any:
+ """
+ Import a python item. Example: import_item("path.to.file", "MyClass") -> MyClass
+ """
+ return getattr(importlib.import_module(path), name)
+
+
+def create_object(config: DictConfig) -> Any:
+ """
+ Create an object from config.
+    The config is expected to contain the following:
+ __object__:
+ path: path.to.module
+ name: MyClass
+        args: as_config | as_params (defaults to as_config)
+ """
+ item = import_item(
+ path=config.__object__.path,
+ name=config.__object__.name,
+ )
+ args = config.__object__.get("args", "as_config")
+ if args == "as_config":
+ return item(config)
+ if args == "as_params":
+ config = OmegaConf.to_object(config)
+ config.pop("__object__")
+ return item(**config)
+ raise NotImplementedError(f"Unknown args type: {args}")
+
+
+def create_dataset(path: str, *args, **kwargs) -> Any:
+ """
+ Create a dataset. Requires the file to contain a "create_dataset" function.
+ """
+ return import_item(path, "create_dataset")(*args, **kwargs)
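`resolve_inheritance` and `create_object` above drive how every component in this repo is instantiated: `__inherit__` merges a parent YAML underneath the current file, and `__object__` names the class to import, with `args: as_config` passing the whole config object (as `generate.yaml` does for the `Generator`) and `args: as_params` expanding the remaining keys as keyword arguments (as the `Wan_*.yaml` model configs do). A minimal sketch of that flow; the `demo_pkg` / `DemoRunner` names are hypothetical placeholders, not part of this PR:

```python
# Hedged sketch of the __inherit__/__object__ convention implemented in humo/common/config.py.
from omegaconf import OmegaConf

parent = OmegaConf.create({"frames": 97, "height": 480})        # stands in for a parent YAML
child = OmegaConf.create({
    "__object__": {"path": "demo_pkg", "name": "DemoRunner",    # hypothetical module/class
                   "args": "as_config"},
    "height": 720,                                              # child keys override the parent
})

# resolve_inheritance merges the child on top of its parent, so child keys win:
merged = OmegaConf.merge(parent, child)
assert merged.height == 720 and merged.frames == 97

# create_object(merged) then resolves to roughly:
#   cls = getattr(importlib.import_module("demo_pkg"), "DemoRunner")
#   runner = cls(merged)                      # because args == "as_config"
# With args == "as_params", __object__ is popped and the rest becomes kwargs: cls(**merged).
```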
diff --git a/humo/common/distributed/__init__.py b/humo/common/distributed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..21e58bc73d36cf0d6b9bdb818dd3f504c04ebbe4
--- /dev/null
+++ b/humo/common/distributed/__init__.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Codes adapted from [SeedVR]
+# https://github.com/ByteDance-Seed/SeedVR/tree/main/common/distributed
+
+"""
+Distributed package.
+"""
+
+from .basic import (
+ barrier_if_distributed,
+ convert_to_ddp,
+ get_device,
+ get_global_rank,
+ get_local_rank,
+ get_world_size,
+ init_torch,
+ meta_param_init_fn,
+ meta_non_persistent_buffer_init_fn
+)
+
+__all__ = [
+ "barrier_if_distributed",
+ "convert_to_ddp",
+ "get_device",
+ "get_global_rank",
+ "get_local_rank",
+ "get_world_size",
+ "init_torch",
+ "meta_param_init_fn",
+ "meta_non_persistent_buffer_init_fn",
+]
diff --git a/humo/common/distributed/advanced.py b/humo/common/distributed/advanced.py
new file mode 100644
index 0000000000000000000000000000000000000000..e349f5c6674028e932145af4b3279766969186e8
--- /dev/null
+++ b/humo/common/distributed/advanced.py
@@ -0,0 +1,484 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Codes adapted from [SeedVR]
+# https://github.com/ByteDance-Seed/SeedVR/tree/main/common/distributed
+
+"""
+Advanced distributed functions for sequence parallel.
+"""
+
+import torch
+from typing import Any, List, Optional, Tuple, Union
+import torch.distributed as dist
+from torch import Tensor
+
+from .basic import get_global_rank, get_world_size
+
+
+_DATA_PARALLEL_GROUP = None
+_SEQUENCE_PARALLEL_GROUP = None
+_SEQUENCE_PARALLEL_CPU_GROUP = None
+
+
+_CFG_PARALLEL_GROUP = None
+_CFG_PARALLEL_CPU_GROUP = None
+
+def get_data_parallel_group() -> Optional[dist.ProcessGroup]:
+ """
+ Get data parallel process group.
+ """
+ return _DATA_PARALLEL_GROUP
+
+
+def get_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
+ """
+ Get sequence parallel process group.
+ """
+ return _SEQUENCE_PARALLEL_GROUP
+
+
+def get_sequence_parallel_cpu_group() -> Optional[dist.ProcessGroup]:
+ """
+ Get sequence parallel CPU process group.
+ """
+ return _SEQUENCE_PARALLEL_CPU_GROUP
+
+
+def get_data_parallel_rank() -> int:
+ """
+ Get data parallel rank.
+ """
+ group = get_data_parallel_group()
+ return dist.get_rank(group) if group else get_global_rank()
+
+
+def get_data_parallel_world_size() -> int:
+ """
+ Get data parallel world size.
+ """
+ group = get_data_parallel_group()
+ return dist.get_world_size(group) if group else get_world_size()
+
+
+def get_sequence_parallel_rank() -> int:
+ """
+ Get sequence parallel rank.
+ """
+ group = get_sequence_parallel_group()
+ return dist.get_rank(group) if group else 0
+
+
+def get_sequence_parallel_world_size() -> int:
+ """
+ Get sequence parallel world size.
+ """
+ group = get_sequence_parallel_group()
+ return dist.get_world_size(group) if group else 1
+
+
+def init_unified_parallel(unified_parallel_size):
+ global _SEQUENCE_PARALLEL_GROUP
+ global _SEQUENCE_PARALLEL_CPU_GROUP
+
+ if unified_parallel_size == 1:
+ return
+
+ assert dist.is_initialized()
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+ assert world_size % unified_parallel_size == 0
+ data_parallel_size = world_size // unified_parallel_size
+
+ for i in range(data_parallel_size):
+ # build unified parallel group
+ start_rank = i * unified_parallel_size
+ end_rank = start_rank + unified_parallel_size
+ unified_parallel_ranks = range(start_rank, end_rank)
+ unified_parallel_group = dist.new_group(unified_parallel_ranks)
+ unified_parallel_cpu_group = dist.new_group(unified_parallel_ranks, backend="gloo")
+ if rank in unified_parallel_ranks:
+ _SEQUENCE_PARALLEL_GROUP = unified_parallel_group
+ _SEQUENCE_PARALLEL_CPU_GROUP = unified_parallel_cpu_group
+
+
+def get_unified_parallel_group():
+ global _SEQUENCE_PARALLEL_GROUP
+ return _SEQUENCE_PARALLEL_GROUP
+
+
+def get_unified_parallel_cpu_group():
+ global _SEQUENCE_PARALLEL_CPU_GROUP
+ return _SEQUENCE_PARALLEL_CPU_GROUP
+
+
+def get_unified_parallel_rank():
+ group = get_unified_parallel_group()
+ return dist.get_rank(group) if group else 0
+
+
+def get_unified_parallel_world_size():
+ group = get_unified_parallel_group()
+ return dist.get_world_size(group) if group else 1
+
+
+def is_unified_parallel_initialized():
+ group = get_unified_parallel_group()
+ return group is not None
+
+
+def pad_tensor(x: Tensor, dim: int, padding_size: int):
+ shape = list(x.shape)
+ shape[dim] = padding_size
+ pad = torch.zeros(shape, dtype=x.dtype, device=x.device)
+ return torch.cat([x, pad], dim=dim)
+
+
+class Slice(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, dim: int, scale_grad: bool) -> Tensor:
+ ctx.group = group
+ ctx.rank = dist.get_rank(group)
+ seq_world_size = dist.get_world_size(group)
+ ctx.seq_world_size = seq_world_size
+ ctx.dim = dim
+ ctx.scale_grad = scale_grad
+ dim_size = local_input.shape[dim]
+ if not ctx.group:
+ return local_input
+ return local_input.split(dim_size // seq_world_size, dim=dim)[ctx.rank].contiguous()
+
+ @staticmethod
+ def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None]:
+ if not ctx.group:
+ return None, grad_output, None, None
+ dim_size = list(grad_output.size())
+ split_size = dim_size[0]
+ dim_size[0] = dim_size[0] * ctx.seq_world_size
+ output = torch.empty(dim_size, dtype=grad_output.dtype, device=torch.cuda.current_device())
+ dist.all_gather_into_tensor(output, grad_output, group=ctx.group)
+ if ctx.scale_grad:
+ output = output / ctx.seq_world_size
+ return (None, torch.cat(output.split(split_size), dim=ctx.dim), None, None)
+
+
+def gather_outputs(
+ x: Tensor,
+ gather_dim: int,
+ padding_dim: Optional[int] = None,
+ unpad_dim_size: Optional[int] = None,
+ scale_grad=True,
+):
+ """
+    A function to gather the model outputs under sequence parallel.
+ """
+ group = get_unified_parallel_group()
+ if not group:
+ return x
+ x = Gather.apply(group, x, gather_dim, scale_grad)
+ if padding_dim is not None:
+ x = unpadding_tensor_for_seqeunce_parallel(x, padding_dim, unpad_dim_size)
+ return x
+
+
+def unpadding_tensor_for_seqeunce_parallel(x: Tensor, dim: int, unpadded_dim_size: int):
+ """
+    A function to remove the padding from a tensor based on its original (unpadded) size.
+ """
+ group = get_unified_parallel_group()
+ if group is None:
+ return x
+ sp_world = get_unified_parallel_world_size()
+ if unpadded_dim_size % sp_world == 0:
+ return x
+ padding_size = sp_world - (unpadded_dim_size % sp_world)
+ assert (padding_size + unpadded_dim_size) % sp_world == 0
+ return unpad_tensor(x, dim=dim, padding_size=padding_size)
+
+
+def gather_seq_scatter_heads_qkv(
+ qkv_tensor: Tensor,
+ seq_dim: int,
+ unpadded_dim_size: Optional[int] = None,
+ restore_shape: bool = True,
+ async_op: bool = False,
+):
+ """
+    A function to sync the split qkv tensor with all-to-all.
+    qkv_tensor: the tensor to run all-to-all on. The last dim must be the
+        projection dim, which is split into 3 parts; after splitting, the
+        gather dim becomes projection_idx + 1.
+    seq_dim: gather_dim for the all-to-all comm.
+    restore_shape: if True, the output keeps the same number of dims as the input.
+ """
+ group = get_unified_parallel_group()
+ if not group:
+ return qkv_tensor
+ world = get_unified_parallel_world_size()
+ orig_shape = qkv_tensor.shape
+ scatter_dim = qkv_tensor.dim()
+ bef_all2all_shape = list(orig_shape)
+ qkv_proj_dim = bef_all2all_shape[-1]
+ bef_all2all_shape = bef_all2all_shape[:-1] + [3, qkv_proj_dim // 3]
+ qkv_tensor = qkv_tensor.view(bef_all2all_shape)
+ if async_op:
+ return SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op)
+ else:
+ qkv_tensor = SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op)
+
+ if restore_shape:
+ out_shape = list(orig_shape)
+ out_shape[seq_dim] *= world
+ out_shape[-1] = qkv_proj_dim // world
+ qkv_tensor = qkv_tensor.view(out_shape)
+
+ # remove padding
+ if unpadded_dim_size and unpadded_dim_size % world != 0:
+ padding_size = qkv_tensor.size(seq_dim) - unpadded_dim_size
+ qkv_tensor = unpad_tensor(qkv_tensor, seq_dim, padding_size)
+
+ return qkv_tensor
+
+
+def gather_seq_scatter_double_head(
+ qkv_tensor: Tensor,
+ seq_dim: int,
+ unpadded_dim_size: Optional[int] = None,
+ restore_shape: bool = True,
+ async_op: bool = False,
+):
+ """
+    A function to sync the split double-head tensor with all-to-all.
+    qkv_tensor: the tensor to run all-to-all on. The last dim must be the
+        projection dim, which is split into 2 parts; after splitting, the
+        gather dim becomes projection_idx + 1.
+    seq_dim: gather_dim for the all-to-all comm.
+    restore_shape: if True, the output keeps the same number of dims as the input.
+ """
+ qkv1_shape = qkv_tensor.shape
+ group = get_unified_parallel_group()
+ if not group:
+ return qkv_tensor
+ world = get_unified_parallel_world_size()
+ orig_shape = qkv_tensor.shape
+ scatter_dim = qkv_tensor.dim()
+ bef_all2all_shape = list(orig_shape)
+ qkv_proj_dim = bef_all2all_shape[-1]
+ bef_all2all_shape = bef_all2all_shape[:-1] + [2, qkv_proj_dim // 2]
+ qkv_tensor = qkv_tensor.view(bef_all2all_shape)
+ qkv2_shape = qkv_tensor.shape
+ if async_op:
+ return SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op)
+ else:
+ qkv_tensor = SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op)
+ qkv3_shape = qkv_tensor.shape
+
+ if restore_shape:
+ out_shape = list(orig_shape)
+ out_shape[seq_dim] *= world
+ out_shape[-1] = qkv_proj_dim // world
+ qkv_tensor = qkv_tensor.view(out_shape)
+ qkv4_shape = qkv_tensor.shape
+
+ # remove padding
+ if unpadded_dim_size and unpadded_dim_size % world != 0:
+ padding_size = qkv_tensor.size(seq_dim) - unpadded_dim_size
+ qkv_tensor = unpad_tensor(qkv_tensor, seq_dim, padding_size)
+ qkv5_shape = qkv_tensor.shape
+
+ return qkv_tensor
+
+
+class SeqAllToAll(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: Any,
+ group: dist.ProcessGroup,
+ local_input: Tensor,
+ scatter_dim: int,
+ gather_dim: int,
+ async_op: bool,
+ ) -> Tensor:
+ ctx.group = group
+ ctx.scatter_dim = scatter_dim
+ ctx.gather_dim = gather_dim
+ ctx.async_op = async_op
+ return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op)
+
+ @staticmethod
+ def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
+ if ctx.async_op:
+ input_t = torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous()
+ else:
+ input_t = grad_output[0]
+ return (
+ None,
+ all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False),
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+def all_to_all_tensor(
+ x: Tensor,
+ scatter_dim: int,
+ gather_dim: int,
+ group: dist.ProcessGroup,
+ async_op: bool = False,
+):
+ if scatter_dim <= 1 and gather_dim <= 1:
+ return _all_to_all_single(x, scatter_dim, gather_dim, group, async_op)
+ else:
+        return _all_to_all(x, scatter_dim, gather_dim, group, async_op)  # this branch is taken
+
+
+def _all_to_all(
+ local_input: Tensor,
+ scatter_dim: int,
+ gather_dim: int,
+ group: dist.ProcessGroup,
+ async_op: bool = False,
+):
+ seq_world_size = dist.get_world_size(group)
+ input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)]
+ output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
+ comm = dist.all_to_all(output_list, input_list, group=group, async_op=async_op)
+ if async_op:
+
+ def wait():
+ comm.wait()
+ return torch.cat(output_list, dim=gather_dim).contiguous()
+
+ return wait
+ return torch.cat(output_list, dim=gather_dim).contiguous()
+
+
+def _all_to_all_single(x: Tensor, scatter_dim: int, gather_dim: int, group: dist.ProcessGroup, async_op: bool = False):
+ """
+    A function to do all-to-all on the first two dims.
+ """
+ sp_world_size = dist.get_world_size(group)
+ assert scatter_dim <= 1, "scatter_dim must be 0 or 1 when using all_to_all_single!"
+ assert gather_dim <= 1, "gather_dim must be 0 or 1 when using all_to_all_single!"
+ if scatter_dim != 0:
+ gather_dim_bef = x.shape[gather_dim]
+ scatter_dim_bef = x.shape[scatter_dim]
+ x = (
+ x.reshape([gather_dim_bef, sp_world_size, scatter_dim_bef // sp_world_size] + list(x.shape[2:]))
+ .transpose(0, 1)
+ .reshape([gather_dim_bef * sp_world_size, scatter_dim_bef // sp_world_size] + list(x.shape[2:]))
+ .contiguous()
+ )
+
+ output = torch.empty_like(x)
+ comm = dist.all_to_all_single(output, x.contiguous(), group=group, async_op=async_op)
+
+ if async_op:
+
+ def wait():
+ comm.wait()
+ if scatter_dim == 0:
+ return torch.cat(output.split(x.size(0) // sp_world_size), dim=gather_dim)
+ else:
+ return output
+
+ return wait
+
+ if scatter_dim == 0:
+ output = torch.cat(output.split(x.size(0) // sp_world_size), dim=gather_dim)
+ return output
+
+
+def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int) -> Tensor:
+ """
+    A function to sync the attention result with all-to-all in sequence parallel.
+ """
+ group = get_unified_parallel_group()
+ if not group:
+ return x
+ dim_size = x.size(seq_dim)
+ sp_world = get_unified_parallel_world_size()
+ if dim_size % sp_world != 0:
+ padding_size = sp_world - (dim_size % sp_world)
+ x = pad_tensor(x, seq_dim, padding_size)
+ return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)
+
+
+def unpad_tensor(x: Tensor, dim: int, padding_size: int):
+ slc = [slice(None)] * len(x.shape)
+ slc[dim] = slice(0, -padding_size)
+ return x[slc]
+
+
+class Gather(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: Any,
+ group: dist.ProcessGroup,
+ local_input: Tensor,
+ dim: int,
+ grad_scale: Optional[bool] = False,
+ ) -> Tensor:
+ ctx.group = group
+ ctx.rank = dist.get_rank(group)
+ ctx.dim = dim
+ ctx.grad_scale = grad_scale
+ seq_world_size = dist.get_world_size(group)
+ ctx.seq_world_size = seq_world_size
+ dim_size = list(local_input.size())
+ split_size = dim_size[0]
+ ctx.part_size = dim_size[dim]
+ dim_size[0] = dim_size[0] * seq_world_size
+ output = torch.empty(dim_size, dtype=local_input.dtype, device=torch.cuda.current_device())
+ dist.all_gather_into_tensor(output, local_input.contiguous(), group=ctx.group)
+ return torch.cat(output.split(split_size), dim=dim)
+
+ @staticmethod
+ def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor]:
+ if ctx.grad_scale:
+ grad_output = grad_output * ctx.seq_world_size
+ return (
+ None,
+ grad_output.split(ctx.part_size, dim=ctx.dim)[ctx.rank].contiguous(),
+ None,
+ None,
+ )
+
+
+def slice_tensor(tensor, dim, start, end):
+ indices = slice(start, end)
+ return tensor[(slice(None),) * dim + (indices,)]
+
+
+def init_model_shard_cpu_group(sharding_strategy: str, device_mesh: Optional[Tuple] = None):
+ """
+ Initialize CPU process group of model sharding.
+ """
+ global _MODEL_SHARD_CPU_GROUP
+ assert dist.is_initialized()
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+ if device_mesh is not None:
+ num_shards_per_group = device_mesh[1]
+ elif "HYBRID" in sharding_strategy:
+ num_shards_per_group = min(8, world_size)
+ else:
+ num_shards_per_group = world_size
+ num_groups = world_size // num_shards_per_group
+ for i in range(num_groups):
+ start_rank = i * num_shards_per_group
+ end_rank = (i + 1) * num_shards_per_group
+ ranks = range(start_rank, end_rank)
+ cpu_group = dist.new_group(ranks, backend="gloo")
+ if rank in ranks:
+ _MODEL_SHARD_CPU_GROUP = cpu_group
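`pad_tensor` and `unpad_tensor` keep sequence lengths divisible by the sequence-parallel world size around the all-to-all calls above. A single-process sanity check of that round trip, using stand-alone equivalents of the two helpers so no process group is needed:

```python
import torch

# Stand-alone equivalents of pad_tensor / unpad_tensor from advanced.py above.
def pad_tensor(x, dim, padding_size):
    shape = list(x.shape)
    shape[dim] = padding_size
    return torch.cat([x, torch.zeros(shape, dtype=x.dtype)], dim=dim)

def unpad_tensor(x, dim, padding_size):
    slc = [slice(None)] * x.ndim
    slc[dim] = slice(0, -padding_size)
    return x[tuple(slc)]

sp_world = 4
x = torch.randn(2, 10, 8)                    # (batch, seq, hidden); seq not divisible by 4
padding = sp_world - x.size(1) % sp_world    # 2 zero rows bring the seq length to 12
padded = pad_tensor(x, dim=1, padding_size=padding)
assert padded.size(1) % sp_world == 0
assert torch.equal(unpad_tensor(padded, dim=1, padding_size=padding), x)
```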
diff --git a/humo/common/distributed/basic.py b/humo/common/distributed/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..364dc73491a3342ab9898c51432180de8d3c3267
--- /dev/null
+++ b/humo/common/distributed/basic.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Codes adapted from [SeedVR]
+# https://github.com/ByteDance-Seed/SeedVR/tree/main/common/distributed
+
+"""
+Distributed basic functions.
+"""
+
+import os
+import torch
+from torch import nn
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel
+from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
+
+
+def get_global_rank() -> int:
+ """
+ Get the global rank, the global index of the GPU.
+ """
+ return int(os.environ.get("RANK", "0"))
+
+
+def get_local_rank() -> int:
+ """
+ Get the local rank, the local index of the GPU.
+ """
+ return int(os.environ.get("LOCAL_RANK", "0"))
+
+
+def get_world_size() -> int:
+ """
+    Get the world size, the total number of GPUs.
+ """
+ return int(os.environ.get("WORLD_SIZE", "1"))
+
+
+def get_device() -> torch.device:
+ """
+ Get current rank device.
+ """
+ return torch.device("cuda", get_local_rank())
+
+
+def barrier_if_distributed(*args, **kwargs):
+ """
+ Synchronizes all processes if under distributed context.
+ """
+ if dist.is_initialized():
+ return dist.barrier(*args, **kwargs)
+
+
+def init_torch(cudnn_benchmark=True):
+ """
+ Common PyTorch initialization configuration.
+ """
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ torch.backends.cudnn.benchmark = cudnn_benchmark
+ torch.cuda.set_device(get_local_rank())
+ dist.init_process_group(
+ backend="nccl",
+ rank=get_global_rank(),
+ world_size=get_world_size(),
+ )
+
+
+def convert_to_ddp(module: torch.nn.Module, **kwargs) -> DistributedDataParallel:
+ return DistributedDataParallel(
+ module=module,
+ device_ids=[get_local_rank()],
+ output_device=get_local_rank(),
+ **kwargs,
+ )
+
+
+def meta_param_init_fn(module: nn.Module) -> None:
+ """
+    Used for models initialized on the meta device.
+    Initializes meta params/buffers with empty tensors.
+    Numerical correctness does not matter here:
+    FSDP will sync param/buffer state from rank 0 to the other ranks.
+ """
+
+ with torch.no_grad():
+ for submodule in module.modules():
+ for param_name, param in submodule.named_parameters(recurse=False):
+ if not _is_fsdp_flattened(param) and param.is_meta:
+ materialized_param = nn.Parameter(torch.empty_like(param, device="cpu"))
+ setattr(submodule, param_name, materialized_param)
+ for buffer_name, buffer in submodule.named_buffers(recurse=False):
+ if not _is_fsdp_flattened(buffer) and buffer.is_meta:
+ materialized_param = torch.empty_like(buffer, device="cpu")
+ setattr(submodule, buffer_name, materialized_param)
+ torch.cuda.empty_cache()
+
+
+def meta_non_persistent_buffer_init_fn(module: nn.Module) -> nn.Module:
+ """
+ Materialize meta device buffers that are not persistent in state_dict.
+ Handles special cases like RotaryEmbedding.freqs.
+ """
+ with torch.no_grad():
+ for submodule in module.modules():
+ if hasattr(submodule, "freqs"):
+ freqs = getattr(submodule, "freqs")
+ if isinstance(freqs, torch.Tensor) and freqs.is_meta:
+ dim = submodule.dim
+ def rope_params(max_seq_len, dim, theta=10000):
+ assert dim % 2 == 0
+ freqs = torch.outer(
+ torch.arange(max_seq_len),
+ 1.0 / torch.pow(theta,
+ torch.arange(0, dim, 2).to(torch.float32).div(dim)))
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
+ return freqs
+
+ dim = 5120 # 1536
+ num_heads = 40 # 12
+ # dim = 1536
+ # num_heads = 12
+ d = dim // num_heads
+ freqs_tensor = torch.cat([
+ rope_params(1024, d - 4 * (d // 6)),
+ rope_params(1024, 2 * (d // 6)),
+ rope_params(1024, 2 * (d // 6))
+ ], dim=1).to(dtype=torch.cfloat, device="cpu")
+
+ setattr(submodule, "freqs", freqs_tensor)
+ print(f"Successfully materialized freqs for {submodule.__class__.__name__}")
+
+ assert not any(b.is_meta for n, b in module.named_buffers())
+ return module
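The hard-coded `dim = 5120`, `num_heads = 40` in `meta_non_persistent_buffer_init_fn` split the 17B model's 128-dim per-head rotary embedding into 44 + 42 + 42 chunks, which yields a complex `freqs` buffer of shape (1024, 64). A small offline check of those shapes, reusing the same `rope_params` construction on CPU (no meta-device model required):

```python
import torch

def rope_params(max_seq_len, dim, theta=10000):
    # Same construction as in meta_non_persistent_buffer_init_fn above.
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
    return torch.polar(torch.ones_like(freqs), freqs)

dim, num_heads = 5120, 40
d = dim // num_heads                                      # 128 per attention head
chunks = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]   # [44, 42, 42]
assert sum(chunks) == d

freqs = torch.cat([rope_params(1024, c) for c in chunks], dim=1)
assert freqs.shape == (1024, d // 2) and freqs.is_complex()
```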
diff --git a/humo/common/logger.py b/humo/common/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..f14adffe7610d97051e97957a34f021e1bd2f6bc
--- /dev/null
+++ b/humo/common/logger.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Codes adapted from [SeedVR]
+# https://github.com/ByteDance-Seed/SeedVR/blob/main/common/logger.py
+
+"""
+Logging utility functions.
+"""
+
+import logging
+import sys
+from typing import Optional
+
+from common.distributed import get_global_rank, get_local_rank, get_world_size
+
+_default_handler = logging.StreamHandler(sys.stdout)
+_default_handler.setFormatter(
+ logging.Formatter(
+ "%(asctime)s "
+ + (f"[Rank:{get_global_rank()}]" if get_world_size() > 1 else "")
+ + (f"[LocalRank:{get_local_rank()}]" if get_world_size() > 1 else "")
+ + "[%(threadName).12s][%(name)s][%(levelname).5s] "
+ + "%(message)s"
+ )
+)
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+ """
+ Get a logger.
+ """
+ logger = logging.getLogger(name)
+ logger.addHandler(_default_handler)
+ logger.setLevel(logging.INFO)
+ return logger
diff --git a/humo/configs/inference/generate.yaml b/humo/configs/inference/generate.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..aad930e312f613930e80e60047dd5669369c91c2
--- /dev/null
+++ b/humo/configs/inference/generate.yaml
@@ -0,0 +1,78 @@
+__object__:
+ path: humo.generate
+ name: Generator
+
+dit:
+ model:
+ __inherit__: humo/configs/models/Wan_14B_I2V.yaml
+ __object__:
+ path: humo.models.wan_modules.model_humo
+ name: WanModel
+ insert_audio: True
+ zero_vae_path: ./weights/HuMo/zero_vae_129frame.pt
+ zero_vae_720p_path: ./weights/HuMo/zero_vae_720p_161frame.pt
+ checkpoint_dir: ./weights/HuMo/HuMo-17B
+ compile: False
+ init_with_meta_device: True
+ gradient_checkpoint: True
+ fsdp:
+ sharding_strategy: _HYBRID_SHARD_ZERO2
+ sp_size: 1
+ dtype: bfloat16
+
+vae:
+ checkpoint: ./weights/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth
+ vae_stride: [ 4, 8, 8 ]
+ scaling_factor: 0.9152
+ compile: False
+ grouping: True
+ use_sample: False
+ dtype: bfloat16
+
+text:
+ t5_checkpoint: ./weights/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth
+ t5_tokenizer: ./weights/Wan2.1-T2V-1.3B/google/umt5-xxl
+ dropout: 0.1
+ dtype: bfloat16
+ fsdp:
+ enabled: True
+ sharding_strategy: HYBRID_SHARD
+
+diffusion:
+ schedule:
+ type: lerp
+ T: 1000.0
+ sampler:
+ type: euler
+ prediction_type: v_lerp
+ timesteps:
+ training:
+ type: logitnormal
+ loc: 0.0
+ scale: 1.0
+ sampling:
+ type: uniform_trailing
+ steps: 50
+ shift: 5.0
+
+audio:
+ vocal_separator: ./weights/HuMo/audio_separator/Kim_Vocal_2.onnx
+ wav2vec_model: ./weights/whisper-large-v3
+
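+# Guidance settings consumed by humo.generate (forward_tia / forward_ti / forward_ta):
+# scale_a weights the audio-conditioned branch, scale_t the text-conditioned branch, and
+# step_change is the timestep below which the negative branch drops the image condition
+# and uses a reduced text scale.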
+generation:
+ mode: "TIA" # TA, TIA
+ extract_audio_feat: True
+ seed: 666666
+ frames: 97
+ fps: 25
+ height: 480 # 480 or 720
+ width: 832 # 832 or 1280
+ batch_size: 1
+ sequence_parallel: 8
+ output:
+ dir: ./output
+ # positive_prompt: ./examples/test_case.json
+ sample_neg_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
+ scale_a: 5.5
+ scale_t: 5.0
+ step_change: 980
\ No newline at end of file
diff --git a/humo/configs/inference/generate_1_7B.yaml b/humo/configs/inference/generate_1_7B.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2a9d1a88ba457e67d01069fa34d45b8e0896a896
--- /dev/null
+++ b/humo/configs/inference/generate_1_7B.yaml
@@ -0,0 +1,76 @@
+__object__:
+ path: humo.generate_1_7B
+ name: Generator
+
+dit:
+ model:
+ __inherit__: humo/configs/models/Wan_1.3B.yaml
+ __object__:
+ path: humo.models.wan_modules.model_humo
+ name: WanModel
+ insert_audio: True
+ zero_vae_path: ./weights/HuMo/zero_vae_129frame.pt
+ zero_vae_720p_path: ./weights/HuMo/zero_vae_720p_161frame.pt
+ checkpoint_dir: ./weights/HuMo/HuMo-1.7B/ema.pth #./weights/HuMo/HuMo-17B
+ compile: False
+ init_with_meta_device: True
+ gradient_checkpoint: True
+ fsdp:
+ sharding_strategy: _HYBRID_SHARD_ZERO2
+ sp_size: 1
+
+vae:
+ checkpoint: ./weights/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth
+ vae_stride: [ 4, 8, 8 ]
+ scaling_factor: 0.9152
+ compile: False
+ grouping: True
+ use_sample: False
+ dtype: bfloat16
+
+text:
+ t5_checkpoint: ./weights/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth
+ t5_tokenizer: ./weights/Wan2.1-T2V-1.3B/google/umt5-xxl
+ dropout: 0.1
+ dtype: bfloat16
+ fsdp:
+ enabled: True
+ sharding_strategy: HYBRID_SHARD
+
+diffusion:
+ schedule:
+ type: lerp
+ T: 1000.0
+ sampler:
+ type: euler
+ prediction_type: v_lerp
+ timesteps:
+ training:
+ type: logitnormal
+ loc: 0.0
+ scale: 1.0
+ sampling:
+ type: uniform_trailing
+ steps: 50
+ shift: 5.0
+
+audio:
+ vocal_separator: ./weights/audio_separator/Kim_Vocal_2.onnx
+ wav2vec_model: ./weights/whisper-large-v3
+
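+# Guidance scales consumed by humo.generate_1_7B.forward_tia: scale_i weights the image
+# (reference) branch, scale_a the audio branch and scale_t the text branch; step_change is
+# left commented out below because the 1.7B path does not switch its negative branch.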
+generation:
+ mode: "TIA" # TA, TIA
+ extract_audio_feat: True
+ seed: 666666
+ frames: 97
+ fps: 25
+ height: 720 # 480 or 720
+ width: 1280 # 832 or 1280
+ batch_size: 1
+ output:
+ dir: ./output
+ sample_neg_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
+ scale_t: 7.5
+ scale_i: 4.0
+ scale_a: 7.5
+ # step_change: 980
\ No newline at end of file
diff --git a/humo/configs/models/Wan_1.3B.yaml b/humo/configs/models/Wan_1.3B.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8202790fff334d1d6dae0343a95d2f8a0bbf571c
--- /dev/null
+++ b/humo/configs/models/Wan_1.3B.yaml
@@ -0,0 +1,17 @@
+__object__:
+ path: ???
+ name: ???
+ args: as_params
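+# "???" marks OmegaConf mandatory values: configs that __inherit__ this file override
+# __object__.path / __object__.name (e.g. with humo.models.wan_modules.model_humo.WanModel).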
+
+text_len: 512
+patch_size: [ 1, 2, 2 ]
+dim: 1536
+ffn_dim: 8960
+freq_dim: 256
+model_type: "t2v"
+num_heads: 12
+num_layers: 30
+window_size: [ -1, -1 ]
+qk_norm: True
+cross_attn_norm: True
+eps: 1e-6
\ No newline at end of file
diff --git a/humo/configs/models/Wan_1.3B_I2V.yaml b/humo/configs/models/Wan_1.3B_I2V.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9c3404757dc718f71b69427349479ff9bc33d761
--- /dev/null
+++ b/humo/configs/models/Wan_1.3B_I2V.yaml
@@ -0,0 +1,18 @@
+__object__:
+ path: ???
+ name: ???
+ args: as_params
+
+text_len: 512
+patch_size: [ 1, 2, 2 ]
+dim: 1536
+ffn_dim: 8960
+freq_dim: 256
+in_dim: 36
+model_type: "i2v"
+num_heads: 12
+num_layers: 30
+window_size: [ -1, -1 ]
+qk_norm: True
+cross_attn_norm: True
+eps: 1e-6
diff --git a/humo/configs/models/Wan_14B.yaml b/humo/configs/models/Wan_14B.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5da9a868af9ae629ecd1a32638a94cc635b1cad9
--- /dev/null
+++ b/humo/configs/models/Wan_14B.yaml
@@ -0,0 +1,17 @@
+__object__:
+ path: ???
+ name: ???
+ args: as_params
+
+text_len: 512
+patch_size: [ 1, 2, 2 ]
+dim: 5120
+ffn_dim: 13824
+freq_dim: 256
+model_type: "t2v"
+num_heads: 40
+num_layers: 40
+window_size: [ -1, -1 ]
+qk_norm: True
+cross_attn_norm: True
+eps: 1e-6
\ No newline at end of file
diff --git a/humo/configs/models/Wan_14B_I2V.yaml b/humo/configs/models/Wan_14B_I2V.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c6d7d0082e688d151a58257a5371422b3e13bd31
--- /dev/null
+++ b/humo/configs/models/Wan_14B_I2V.yaml
@@ -0,0 +1,18 @@
+__object__:
+ path: ???
+ name: ???
+ args: as_params
+
+text_len: 512
+patch_size: [ 1, 2, 2 ]
+dim: 5120
+ffn_dim: 13824
+freq_dim: 256
+in_dim: 36
+model_type: "i2v"
+num_heads: 40
+num_layers: 40
+window_size: [ -1, -1 ]
+qk_norm: True
+cross_attn_norm: True
+eps: 1e-6
\ No newline at end of file
diff --git a/humo/generate.py b/humo/generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..c185f15a427ece6f72c75caa7dfd1e32ae278aae
--- /dev/null
+++ b/humo/generate.py
@@ -0,0 +1,984 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Inference code adapted from [SeedVR]
+# https://github.com/ByteDance-Seed/SeedVR/blob/main/projects/inference_seedvr2_7b.py
+
+import math
+import os
+import gc
+import random
+import sys
+import mediapy
+import numpy as np
+import torch
+import torch.distributed as dist
+from omegaconf import DictConfig, ListConfig, OmegaConf
+from einops import rearrange
+from PIL import Image, ImageOps
+from torchvision.transforms import ToTensor
+from tqdm import tqdm
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.fsdp import (
+ BackwardPrefetch,
+ FullyShardedDataParallel,
+ MixedPrecision,
+ ShardingStrategy,
+)
+from common.distributed import (
+ get_device,
+ get_global_rank,
+ get_local_rank,
+ meta_param_init_fn,
+ meta_non_persistent_buffer_init_fn,
+ init_torch,
+)
+from common.distributed.advanced import (
+ init_unified_parallel,
+ get_unified_parallel_world_size,
+ get_sequence_parallel_rank,
+ init_model_shard_cpu_group,
+)
+from common.logger import get_logger
+from common.config import create_object
+from common.distributed import get_device, get_global_rank
+from torchvision.transforms import Compose, Normalize, ToTensor
+from humo.models.wan_modules.t5 import T5EncoderModel
+from humo.models.wan_modules.vae import WanVAE
+from humo.models.utils.utils import tensor_to_video, prepare_json_dataset
+from contextlib import contextmanager
+import torch.cuda.amp as amp
+from humo.models.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+from humo.utils.audio_processor_whisper import AudioProcessor
+from humo.utils.wav2vec import linear_interpolation_fps
+from torchao.quantization import quantize_
+
+import torch._dynamo as dynamo
+dynamo.config.capture_scalar_outputs = True
+torch.set_float32_matmul_precision("high")
+
+import torch.nn as nn
+import transformer_engine.pytorch as te
+
+image_transform = Compose([
+ ToTensor(),
+ Normalize(mean=0.5, std=0.5),
+])
+
+SIZE_CONFIGS = {
+ '720*1280': (720, 1280),
+ '1280*720': (1280, 720),
+ '480*832': (480, 832),
+ '832*480': (832, 480),
+ '1024*1024': (1024, 1024),
+}
+
+def clever_format(nums, format="%.2f"):
+ from typing import Iterable
+ if not isinstance(nums, Iterable):
+ nums = [nums]
+ clever_nums = []
+ for num in nums:
+ if num > 1e12:
+ clever_nums.append(format % (num / 1e12) + "T")
+ elif num > 1e9:
+ clever_nums.append(format % (num / 1e9) + "G")
+ elif num > 1e6:
+ clever_nums.append(format % (num / 1e6) + "M")
+ elif num > 1e3:
+ clever_nums.append(format % (num / 1e3) + "K")
+ else:
+ clever_nums.append(format % num + "B")
+
+ clever_nums = clever_nums[0] if len(clever_nums) == 1 else (*clever_nums,)
+
+ return clever_nums
+
+
+
+# --- FP8 (Transformer Engine) compatibility helpers ---
+import contextlib
+
+# FP8 autocast compatibility for different TE versions
+try:
+ # Preferred modern API
+ from transformer_engine.pytorch import fp8_autocast
+ try:
+ # Newer TE: use recipe-based API
+ from transformer_engine.common.recipe import DelayedScaling, Format
+ def make_fp8_ctx(enabled: bool = True):
+ if not enabled:
+ return contextlib.nullcontext()
+ fp8_recipe = DelayedScaling(fp8_format=Format.E4M3) # E4M3 format
+ return fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)
+ except Exception:
+ # Very old variant that might still accept fp8_format directly
+ def make_fp8_ctx(enabled: bool = True):
+ # If TE doesn't have FP8Format, just no-op
+ if not hasattr(te, "FP8Format"):
+ return contextlib.nullcontext()
+ return te.fp8_autocast(enabled=enabled, fp8_format=te.FP8Format.E4M3)
+except Exception:
+ # TE not present or totally incompatible — no-op
+ def make_fp8_ctx(enabled: bool = True):
+ return contextlib.nullcontext()
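+# Intended usage (sketch): wrap DiT forward passes as
+#   with make_fp8_ctx(True):
+#       out = model(x)
+# FP8 kicks in when Transformer Engine supports it; otherwise the context degrades to a no-op.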
+
+
+# TE sometimes exposes Linear at different paths; this normalizes it.
+try:
+ TELinear = te.Linear
+except AttributeError: # very old layouts
+ from transformer_engine.pytorch.modules.linear import Linear as TELinear # type: ignore
+
+
+def _default_te_allow(fullname: str, lin: nn.Linear) -> bool:
+ """
+ Allow TE only where it's shape-safe & beneficial.
+ Skip small/special layers (time/timestep/pos embeds, heads).
+ Enforce multiples of 16 for in/out features (FP8 kernel friendly).
+ Also skip very small projections likely to see M=1.
+ """
+ blocked_keywords = (
+ "time_embedding", "timestep", "time_embed",
+ "time_projection", "pos_embedding", "pos_embed",
+ "to_logits", "logits", "final_proj", "proj_out", "output_projection",
+ )
+ if any(k in fullname for k in blocked_keywords):
+ return False
+
+ # TE FP8 kernels like K, N divisible by 16
+ if lin.in_features % 16 != 0 or lin.out_features % 16 != 0:
+ return False
+
+ # Heuristic: avoid tiny layers; keeps attention/MLP, skips small MLPs
+ if lin.in_features < 512 or lin.out_features < 512:
+ return False
+
+ # Whitelist: only convert inside transformer blocks if you know their prefix
+ # This further reduces risk of catching special heads elsewhere.
+ allowed_context = ("blocks", "layers", "transformer", "attn", "mlp", "ffn")
+ if not any(tok in fullname for tok in allowed_context):
+ return False
+
+ return True
+
+@torch.no_grad()
+def convert_linears_to_te_fp8(module: nn.Module, allow_pred=_default_te_allow, _prefix=""):
+ for name, child in list(module.named_children()):
+ full = f"{_prefix}.{name}" if _prefix else name
+ convert_linears_to_te_fp8(child, allow_pred, full)
+
+ if isinstance(child, nn.Linear):
+ if allow_pred is not None and not allow_pred(full, child):
+ continue
+
+ te_lin = TELinear(
+ in_features=child.in_features,
+ out_features=child.out_features,
+ bias=(child.bias is not None),
+ params_dtype=torch.bfloat16,
+ ).to(child.weight.device)
+
+ te_lin.weight.copy_(child.weight.to(te_lin.weight.dtype))
+ if child.bias is not None:
+ te_lin.bias.copy_(child.bias.to(te_lin.bias.dtype))
+
+ setattr(module, name, te_lin)
+ return module
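+# Note: convert_linears_to_te_fp8 swaps eligible nn.Linear modules in place for te.Linear with
+# bf16 parameters; FP8 execution only happens inside make_fp8_ctx and still depends on
+# hardware/TE support, otherwise the converted layers simply run in bf16.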
+
+class Generator():
+ def __init__(self, config: DictConfig):
+ self.config = config.copy()
+ OmegaConf.set_readonly(self.config, True)
+ self.logger = get_logger(self.__class__.__name__)
+
+ # init_torch(cudnn_benchmark=False)
+ self.configure_models()
+
+ def entrypoint(self):
+
+ self.inference_loop()
+
+ def get_fsdp_sharding_config(self, sharding_strategy, device_mesh_config):
+ device_mesh = None
+ fsdp_strategy = ShardingStrategy[sharding_strategy]
+ if (
+ fsdp_strategy in [ShardingStrategy._HYBRID_SHARD_ZERO2, ShardingStrategy.HYBRID_SHARD]
+ and device_mesh_config is not None
+ ):
+ device_mesh = init_device_mesh("cuda", tuple(device_mesh_config))
+ return device_mesh, fsdp_strategy
+
+
+ def configure_models(self):
+ self.configure_dit_model(device="cuda")
+
+ self.dit.eval().to("cuda")
+ convert_linears_to_te_fp8(self.dit)
+
+ self.dit = torch.compile(self.dit)
+
+
+ self.configure_vae_model(device="cuda")
+ if self.config.generation.get('extract_audio_feat', False):
+ self.configure_wav2vec(device="cpu")
+ self.configure_text_model(device="cuda")
+
+ # # Initialize fsdp.
+ # self.configure_dit_fsdp_model()
+ # self.configure_text_fsdp_model()
+
+ # quantize_(self.text_encoder, Int8WeightOnlyConfig())
+ # quantize_(self.dit, Float8DynamicActivationFloat8WeightConfig())
+
+
+ def configure_dit_model(self, device=get_device()):
+
+ init_unified_parallel(self.config.dit.sp_size)
+ self.sp_size = get_unified_parallel_world_size()
+
+ # Create DiT model on meta, then mark dtype as bfloat16 (no real allocation yet).
+ init_device = "meta"
+ with torch.device(init_device):
+ self.dit = create_object(self.config.dit.model)
+ self.dit = self.dit.to(dtype=torch.bfloat16) # or: self.dit.bfloat16()
+ self.logger.info(f"Load DiT model on {init_device}.")
+ self.dit.eval().requires_grad_(False)
+
+ # Load dit checkpoint.
+ path = self.config.dit.checkpoint_dir
+
+ def _cast_state_dict_to_bf16(state):
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor) and v.is_floating_point():
+ state[k] = v.to(dtype=torch.bfloat16, copy=False)
+ return state
+
+ if path.endswith(".pth"):
+ # Load to CPU first; we’ll move the model later.
+ state = torch.load(path, map_location="cpu", mmap=True)
+ state = _cast_state_dict_to_bf16(state)
+ missing_keys, unexpected_keys = self.dit.load_state_dict(state, strict=False, assign=True)
+ self.logger.info(
+ f"dit loaded from {path}. Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}"
+ )
+ else:
+ from safetensors.torch import load_file
+ import json
+ def load_custom_sharded_weights(model_dir, base_name):
+ index_path = f"{model_dir}/{base_name}.safetensors.index.json"
+ with open(index_path, "r") as f:
+ index = json.load(f)
+ weight_map = index["weight_map"]
+ shard_files = set(weight_map.values())
+ state_dict = {}
+ for shard_file in shard_files:
+ shard_path = f"{model_dir}/{shard_file}"
+ # Load on CPU, then cast to bf16; we’ll move the whole module later.
+ shard_state = load_file(shard_path, device="cpu")
+ shard_state = {k: (v.to(dtype=torch.bfloat16, copy=False) if v.is_floating_point() else v)
+ for k, v in shard_state.items()}
+ state_dict.update(shard_state)
+ return state_dict
+
+ state = load_custom_sharded_weights(path, 'humo')
+ self.dit.load_state_dict(state, strict=False, assign=True)
+
+ self.dit = meta_non_persistent_buffer_init_fn(self.dit)
+
+ target_device = get_device() if device in [get_device(), "cuda"] else device
+ self.dit.to(target_device) # dtype already bf16
+
+ # Print model size.
+ params = sum(p.numel() for p in self.dit.parameters())
+ self.logger.info(
+ f"[RANK:{get_global_rank()}] DiT Parameters: {clever_format(params, '%.3f')}"
+ )
+
+
+ def configure_vae_model(self, device=get_device()):
+ self.vae_stride = self.config.vae.vae_stride
+ self.vae = WanVAE(
+ vae_pth=self.config.vae.checkpoint,
+ device=device)
+
+ if self.config.generation.height == 480:
+ self.zero_vae = torch.load(self.config.dit.zero_vae_path)
+ elif self.config.generation.height == 720:
+ self.zero_vae = torch.load(self.config.dit.zero_vae_720p_path)
+ else:
+ raise ValueError(f"Unsupported height {self.config.generation.height} for zero-vae.")
+
+ def configure_wav2vec(self, device=get_device()):
+ audio_separator_model_file = self.config.audio.vocal_separator
+ wav2vec_model_path = self.config.audio.wav2vec_model
+
+ self.audio_processor = AudioProcessor(
+ 16000,
+ 25,
+ wav2vec_model_path,
+ "all",
+ audio_separator_model_file,
+ None, # not separate
+ os.path.join(self.config.generation.output.dir, "vocals"),
+ device=device,
+ )
+
+ def configure_text_model(self, device=get_device()):
+ self.text_encoder = T5EncoderModel(
+ text_len=self.config.dit.model.text_len,
+ dtype=torch.bfloat16,
+ device=device,
+ checkpoint_path=self.config.text.t5_checkpoint,
+ tokenizer_path=self.config.text.t5_tokenizer,
+ )
+
+
+ def configure_dit_fsdp_model(self):
+ from humo.models.wan_modules.model_humo import WanAttentionBlock
+
+ dit_blocks = (WanAttentionBlock,)
+
+ # Init model_shard_cpu_group for saving checkpoint with sharded state_dict.
+ init_model_shard_cpu_group(
+ self.config.dit.fsdp.sharding_strategy,
+ self.config.dit.fsdp.get("device_mesh", None),
+ )
+
+ # Assert that dit has wrappable blocks.
+ assert any(isinstance(m, dit_blocks) for m in self.dit.modules())
+
+ # Define wrap policy on all dit blocks.
+ def custom_auto_wrap_policy(module, recurse, *args, **kwargs):
+ return recurse or isinstance(module, dit_blocks)
+
+ # Configure FSDP settings.
+ device_mesh, fsdp_strategy = self.get_fsdp_sharding_config(
+ self.config.dit.fsdp.sharding_strategy,
+ self.config.dit.fsdp.get("device_mesh", None),
+ )
+ settings = dict(
+ auto_wrap_policy=custom_auto_wrap_policy,
+ sharding_strategy=fsdp_strategy,
+ backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+ device_id=get_local_rank(),
+ use_orig_params=False,
+ sync_module_states=True,
+ forward_prefetch=True,
+ limit_all_gathers=False, # False for ZERO2.
+ mixed_precision=MixedPrecision(
+ param_dtype=torch.bfloat16,
+ reduce_dtype=torch.float32,
+ buffer_dtype=torch.float32,
+ ),
+ device_mesh=device_mesh,
+ param_init_fn=meta_param_init_fn,
+ )
+
+ # Apply FSDP.
+ self.dit = FullyShardedDataParallel(self.dit, **settings)
+ # self.dit.to(get_device())
+
+
+ def configure_text_fsdp_model(self):
+ # If FSDP is not enabled, put text_encoder to GPU and return.
+ if not self.config.text.fsdp.enabled:
+ self.text_encoder.to(get_device())
+ return
+
+ # from transformers.models.t5.modeling_t5 import T5Block
+ from humo.models.wan_modules.t5 import T5SelfAttention
+
+ text_blocks = (torch.nn.Embedding, T5SelfAttention)
+ # text_blocks_names = ("QWenBlock", "QWenModel") # QWen cannot be imported. Use str.
+
+ def custom_auto_wrap_policy(module, recurse, *args, **kwargs):
+ return (
+ recurse
+ or isinstance(module, text_blocks)
+ )
+
+ # Apply FSDP.
+ text_encoder_dtype = getattr(torch, self.config.text.dtype)
+ device_mesh, fsdp_strategy = self.get_fsdp_sharding_config(
+ self.config.text.fsdp.sharding_strategy,
+ self.config.text.fsdp.get("device_mesh", None),
+ )
+ self.text_encoder = FullyShardedDataParallel(
+ module=self.text_encoder,
+ auto_wrap_policy=custom_auto_wrap_policy,
+ sharding_strategy=fsdp_strategy,
+ backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+ device_id=get_local_rank(),
+ use_orig_params=False,
+ sync_module_states=False,
+ forward_prefetch=True,
+ limit_all_gathers=True,
+ mixed_precision=MixedPrecision(
+ param_dtype=text_encoder_dtype,
+ reduce_dtype=text_encoder_dtype,
+ buffer_dtype=text_encoder_dtype,
+ ),
+ device_mesh=device_mesh,
+ )
+ self.text_encoder.to(get_device()).requires_grad_(False)
+
+
+ def load_image_latent_ref_id(self, path: str, size, device):
+ # Load size.
+ h, w = size[1], size[0]
+
+ # Load image.
+ if len(path) > 1 and not isinstance(path, str):
+ ref_vae_latents = []
+ for image_path in path:
+ with Image.open(image_path) as img:
+ img = img.convert("RGB")
+
+ # Calculate the required size to keep aspect ratio and fill the rest with padding.
+ img_ratio = img.width / img.height
+ target_ratio = w / h
+
+ if img_ratio > target_ratio: # Image is wider than target
+ new_width = w
+ new_height = int(new_width / img_ratio)
+ else: # Image is taller than target
+ new_height = h
+ new_width = int(new_height * img_ratio)
+
+ # img = img.resize((new_width, new_height), Image.ANTIALIAS)
+ img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+ # Create a new image with the target size and place the resized image in the center
+ delta_w = w - img.size[0]
+ delta_h = h - img.size[1]
+ padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
+ new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
+
+ # Transform to tensor and normalize.
+ transform = Compose(
+ [
+ ToTensor(),
+ Normalize(0.5, 0.5),
+ ]
+ )
+ new_img = transform(new_img)
+ # img_vae_latent = self.vae_encode([new_img.unsqueeze(1)])[0]
+ img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device)
+ ref_vae_latents.append(img_vae_latent[0])
+
+ return [torch.cat(ref_vae_latents, dim=1)]
+ else:
+ if not isinstance(path, str):
+ path = path[0]
+ with Image.open(path) as img:
+ img = img.convert("RGB")
+
+ # Calculate the required size to keep aspect ratio and fill the rest with padding.
+ img_ratio = img.width / img.height
+ target_ratio = w / h
+
+ if img_ratio > target_ratio: # Image is wider than target
+ new_width = w
+ new_height = int(new_width / img_ratio)
+ else: # Image is taller than target
+ new_height = h
+ new_width = int(new_height * img_ratio)
+
+ # img = img.resize((new_width, new_height), Image.ANTIALIAS)
+ img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+ # Create a new image with the target size and place the resized image in the center
+ delta_w = w - img.size[0]
+ delta_h = h - img.size[1]
+ padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
+ new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
+
+ # Transform to tensor and normalize.
+ transform = Compose(
+ [
+ ToTensor(),
+ Normalize(0.5, 0.5),
+ ]
+ )
+ new_img = transform(new_img)
+ img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device)
+
+ # Vae encode.
+ return img_vae_latent
+
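+ # For each latent frame (one latent frame spans 4 video frames), gather a window of
+ # 2 * audio_shift + 4 = 8 per-frame audio embeddings around it, zero-padding indices that
+ # fall outside the audio track. Returns the stacked windows and the next start index.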
+ def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2):
+ zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
+ zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) # device=audio_emb.device
+ iter_ = 1 + (frame_num - 1) // 4
+ audio_emb_wind = []
+ for lt_i in range(iter_):
+ if lt_i == 0:
+ st = frame0_idx + lt_i - 2
+ ed = frame0_idx + lt_i + 3
+ wind_feat = torch.stack([
+ audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
+ for i in range(st, ed)
+ ], dim=0)
+ wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0)
+ else:
+ st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
+ ed = frame0_idx + 1 + 4 * lt_i + audio_shift
+ wind_feat = torch.stack([
+ audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
+ for i in range(st, ed)
+ ], dim=0)
+ audio_emb_wind.append(wind_feat)
+ audio_emb_wind = torch.stack(audio_emb_wind, dim=0)
+
+ return audio_emb_wind, ed - audio_shift
+
+ def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"):
+ if wav_enc_type == "wav2vec":
+ feat_merge = audio_emb
+ elif wav_enc_type == "whisper":
+ feat0 = linear_interpolation_fps(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25)
+ feat1 = linear_interpolation_fps(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25)
+ feat2 = linear_interpolation_fps(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25)
+ feat3 = linear_interpolation_fps(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25)
+ feat4 = linear_interpolation_fps(audio_emb[:, :, 32], 50, 25)
+ feat_merge = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0]
+ else:
+ raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}")
+
+ return feat_merge
+
+ def parse_output(self, output):
+ latent = output[0]
+ mask = None
+ return latent, mask
+
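+ # The forward_* helpers compose classifier-free guidance over nested condition sets, e.g.
+ # forward_tia: scale_a * (text+image+audio - text+image) + scale_t * (text+image - neg) + neg,
+ # where the negative branch keeps the image condition only while t > step_change.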
+ def forward_tia(self, latents, timestep, t, step_change, arg_tia, arg_ti, arg_i, arg_null):
+ pos_tia, _ = self.parse_output(self.dit(
+ latents, t=timestep, **arg_tia
+ ))
+ torch.cuda.empty_cache()
+
+ pos_ti, _ = self.parse_output(self.dit(
+ latents, t=timestep, **arg_ti
+ ))
+ torch.cuda.empty_cache()
+
+ if t > step_change:
+ neg, _ = self.parse_output(self.dit(
+ latents, t=timestep, **arg_i
+ )) # img included in null, same with official Wan-2.1
+ torch.cuda.empty_cache()
+
+ noise_pred = self.config.generation.scale_a * (pos_tia - pos_ti) + \
+ self.config.generation.scale_t * (pos_ti - neg) + \
+ neg
+ else:
+ neg, _ = self.parse_output(self.dit(
+ latents, t=timestep, **arg_null
+ )) # img not included in null
+ torch.cuda.empty_cache()
+
+ noise_pred = self.config.generation.scale_a * (pos_tia - pos_ti) + \
+ (self.config.generation.scale_t - 2.0) * (pos_ti - neg) + \
+ neg
+ return noise_pred
+
+ def forward_ti(self, latents, timestep, t, step_change, arg_ti, arg_t, arg_i, arg_null):
+ # Positive with text+image (no audio)
+ pos_ti, _ = self.parse_output(self.dit(
+ latents, t=timestep, **arg_ti
+ ))
+ torch.cuda.empty_cache()
+
+ # Positive with text only (no image, no audio)
+ pos_t, _ = self.parse_output(self.dit(
+ latents, t=timestep, **arg_t
+ ))
+ torch.cuda.empty_cache()
+
+ # Negative branch: before step_change, don't include image in null; after, include image (like Wan-2.1)
+ if t > step_change:
+ neg, _ = self.parse_output(self.dit(
+ latents, t=timestep, **arg_i
+ )) # img included in null
+ else:
+ neg, _ = self.parse_output(self.dit(
+ latents, t=timestep, **arg_null
+ )) # img NOT included in null
+ torch.cuda.empty_cache()
+
+ # Guidance blend: replace "scale_a" below with "scale_i" if you add a separate image scale in config
+ noise_pred = self.config.generation.scale_a * (pos_ti - pos_t) + \
+ self.config.generation.scale_t * (pos_t - neg) + \
+ neg
+ return noise_pred
+
+ def forward_ta(self, latents, timestep, arg_ta, arg_t, arg_null):
+ pos_ta, _ = self.parse_output(self.dit(
+ latents, t=timestep, **arg_ta
+ ))
+ torch.cuda.empty_cache()
+
+ pos_t, _ = self.parse_output(self.dit(
+ latents, t=timestep, **arg_t
+ ))
+ torch.cuda.empty_cache()
+
+ neg, _ = self.parse_output(self.dit(
+ latents, t=timestep, **arg_null
+ ))
+ torch.cuda.empty_cache()
+
+ noise_pred = self.config.generation.scale_a * (pos_ta - pos_t) + \
+ self.config.generation.scale_t * (pos_t - neg) + \
+ neg
+ return noise_pred
+
+ @torch.no_grad()
+ def inference(self,
+ input_prompt,
+ img_path,
+ audio_path,
+ size=(1280, 720),
+ frame_num=81,
+ shift=5.0,
+ sample_solver='unipc',
+ inference_mode='TIA',
+ sampling_steps=50,
+ n_prompt="",
+ seed=-1,
+ tea_cache_l1_thresh = 0.0,
+ device = get_device(),
+ ):
+
+ print("inference started")
+
+ # self.vae.model.to(device=device)
+ if img_path is not None:
+ latents_ref = self.load_image_latent_ref_id(img_path, size, device)
+ else:
+ latents_ref = [torch.zeros(16, 1, size[1]//8, size[0]//8).to(device)]
+
+ # self.vae.model.to(device="cpu")
+
+ print("vae finished")
+
+ latents_ref_neg = [torch.zeros_like(latent_ref) for latent_ref in latents_ref]
+
+ # audio
+ if audio_path is not None:
+ if self.config.generation.extract_audio_feat:
+ self.audio_processor.whisper.to(device=device)
+ audio_emb, audio_length = self.audio_processor.preprocess(audio_path)
+ self.audio_processor.whisper.to(device='cpu')
+ else:
+ audio_emb_path = audio_path.replace(".wav", ".pt")
+ audio_emb = torch.load(audio_emb_path).to(device=device)
+ audio_emb = self.audio_emb_enc(audio_emb, wav_enc_type="whisper")
+ self.logger.info("使用预先提取好的音频特征: %s", audio_emb_path)
+ else:
+ audio_emb = torch.zeros(frame_num, 5, 1280).to(device)
+
+ frame_num = frame_num if frame_num != -1 else audio_length
+ frame_num = 4 * ((frame_num - 1) // 4) + 1
+ audio_emb, _ = self.get_audio_emb_window(audio_emb, frame_num, frame0_idx=0)
+ zero_audio_pad = torch.zeros(latents_ref[0].shape[1], *audio_emb.shape[1:]).to(audio_emb.device)
+ audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0)
+ audio_emb = [audio_emb.to(device)]
+ audio_emb_neg = [torch.zeros_like(audio_emb[0])]
+
+ # preprocess
+ self.patch_size = self.config.dit.model.patch_size
+ F = frame_num
+ target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + latents_ref[0].shape[1],
+ size[1] // self.vae_stride[1],
+ size[0] // self.vae_stride[2])
+
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) /
+ (self.patch_size[1] * self.patch_size[2]) *
+ target_shape[1] / self.sp_size) * self.sp_size
+
+ if n_prompt == "":
+ n_prompt = self.config.generation.sample_neg_prompt
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
+ seed_g = torch.Generator(device=device)
+ seed_g.manual_seed(seed)
+
+ # self.text_encoder.model.to(device)
+ context = self.text_encoder([input_prompt], device)
+ context_null = self.text_encoder([n_prompt], device)
+ # self.text_encoder.model.cpu()
+
+ print("text encoder finished")
+
+ noise = [
+ torch.randn(
+ target_shape[0],
+ target_shape[1], # - latents_ref[0].shape[1],
+ target_shape[2],
+ target_shape[3],
+ dtype=torch.float32,
+ device=device,
+ generator=seed_g)
+ ]
+
+ @contextmanager
+ def noop_no_sync():
+ yield
+
+ no_sync = getattr(self.dit, 'no_sync', noop_no_sync)
+ step_change = self.config.generation.step_change # 980
+
+ # evaluation mode
+ with make_fp8_ctx(True), torch.autocast('cuda', dtype=torch.bfloat16), torch.no_grad(), no_sync():
+
+ if sample_solver == 'unipc':
+ sample_scheduler = FlowUniPCMultistepScheduler(
+ num_train_timesteps=1000,
+ shift=1,
+ use_dynamic_shifting=False)
+ sample_scheduler.set_timesteps(
+ sampling_steps, device=device, shift=shift)
+ timesteps = sample_scheduler.timesteps
+
+ # sample videos
+ latents = noise
+
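+ # Conditioning input y: a binary mask marking the trailing reference-latent frames,
+ # concatenated with zero-VAE latents for the frames to generate plus the reference
+ # latents (y_c), or with zero-VAE latents only (y_null).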
+ msk = torch.ones(4, target_shape[1], target_shape[2], target_shape[3], device=get_device())
+ msk[:,:-latents_ref[0].shape[1]] = 0
+
+ zero_vae = self.zero_vae[:, :(target_shape[1]-latents_ref[0].shape[1])].to(
+ device=get_device(), dtype=latents_ref[0].dtype)
+ y_c = torch.cat([
+ zero_vae,
+ latents_ref[0]
+ ], dim=1)
+ y_c = [torch.concat([msk, y_c])]
+
+ y_null = self.zero_vae[:, :target_shape[1]].to(
+ device=get_device(), dtype=latents_ref[0].dtype)
+ y_null = [torch.concat([msk, y_null])]
+
+ tea_cache_model_id = "Wan2.1-T2V-14B"
+
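+ # Conditioning combinations for guidance: null = no text/image/audio, t = text only,
+ # i = image only, ti = text+image, ta = text+audio, tia = text+image+audio. Each branch
+ # gets its own TeaCache instance when tea_cache_l1_thresh > 0.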
+ def _make_tea_cache():
+ if tea_cache_l1_thresh is None or tea_cache_l1_thresh <= 0:
+ return None
+ return TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)
+
+ arg_null = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_null, 'context': context_null, "tea_cache": _make_tea_cache()}
+ arg_t = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_null, 'context': context, "tea_cache": _make_tea_cache()}
+ arg_i = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_c, 'context': context_null, "tea_cache": _make_tea_cache()}
+ arg_ti = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_c, 'context': context, "tea_cache": _make_tea_cache()}
+ arg_ta = {'seq_len': seq_len, 'audio': audio_emb, 'y': y_null, 'context': context, "tea_cache": _make_tea_cache()}
+ arg_tia = {'seq_len': seq_len, 'audio': audio_emb, 'y': y_c, 'context': context, "tea_cache": _make_tea_cache()}
+
+ torch.cuda.empty_cache()
+ # self.dit.to(device=get_device())
+ for _, t in enumerate(tqdm(timesteps)):
+ timestep = [t]
+ timestep = torch.stack(timestep)
+
+ if inference_mode == "TIA":
+ noise_pred = self.forward_tia(latents, timestep, t, step_change,
+ arg_tia, arg_ti, arg_i, arg_null)
+ elif inference_mode == "TA":
+ noise_pred = self.forward_ta(latents, timestep, arg_ta, arg_t, arg_null)
+ elif inference_mode == "TI":
+ noise_pred = self.forward_ti(latents, timestep, t, step_change,
+ arg_ti, arg_t, arg_i, arg_null)
+ else:
+ raise ValueError(f"Unsupported generation mode: {self.config.generation.mode}")
+
+ temp_x0 = sample_scheduler.step(
+ noise_pred.unsqueeze(0),
+ t,
+ latents[0].unsqueeze(0),
+ return_dict=False,
+ generator=seed_g)[0]
+ latents = [temp_x0.squeeze(0)]
+
+ del timestep
+ torch.cuda.empty_cache()
+
+ x0 = latents
+ x0 = [x0_[:,:-latents_ref[0].shape[1]] for x0_ in x0]
+
+ # if offload_model:
+ # self.dit.cpu()
+
+ print("dit finished")
+
+ torch.cuda.empty_cache()
+ # if get_local_rank() == 0:
+ # self.vae.model.to(device=device)
+ videos = self.vae.decode(x0)
+ # self.vae.model.to(device="cpu")
+
+ print("vae 2 finished")
+
+ del noise, latents, noise_pred
+ del audio_emb, audio_emb_neg, latents_ref, latents_ref_neg, context, context_null
+ del x0, temp_x0
+ del sample_scheduler
+ torch.cuda.empty_cache()
+ gc.collect()
+ torch.cuda.synchronize()
+ if dist.is_initialized():
+ dist.barrier()
+
+ return videos[0] # if get_local_rank() == 0 else None
+
+
+ def inference_loop(self, prompt, ref_img_path, audio_path, output_dir, filename, inference_mode = "TIA", width = 832, height = 480, steps=50, frames = 97, tea_cache_l1_thresh = 0.0, seed = 0):
+
+ video = self.inference(
+ prompt,
+ ref_img_path,
+ audio_path,
+ size=SIZE_CONFIGS[f"{width}*{height}"],
+ frame_num=frames,
+ shift=self.config.diffusion.timesteps.sampling.shift,
+ sample_solver='unipc',
+ sampling_steps=steps,
+ inference_mode = inference_mode,
+ tea_cache_l1_thresh = tea_cache_l1_thresh,
+ seed=seed
+ )
+
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ # Save samples.
+ if get_sequence_parallel_rank() == 0:
+ pathname = self.save_sample(
+ sample=video,
+ audio_path=audio_path,
+ output_dir = output_dir,
+ filename=filename,
+ )
+ self.logger.info(f"Finished {filename}, saved to {pathname}.")
+
+ del video, prompt
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
+ def save_sample(self, *, sample: torch.Tensor, audio_path: str, output_dir: str, filename: str):
+ gen_config = self.config.generation
+ # Prepare file path.
+ extension = ".mp4" if sample.ndim == 4 else ".png"
+ filename += extension
+ pathname = os.path.join(output_dir, filename)
+ # Convert sample.
+ sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).to("cpu", torch.uint8)
+ sample = rearrange(sample, "c t h w -> t h w c")
+ # Save file.
+ if sample.ndim == 4:
+ if audio_path is not None:
+ tensor_to_video(
+ sample.numpy(),
+ pathname,
+ audio_path,
+ fps=gen_config.fps)
+ else:
+ mediapy.write_video(
+ path=pathname,
+ images=sample.numpy(),
+ fps=gen_config.fps,
+ )
+ else:
+ raise ValueError
+ return pathname
+
+
+ def prepare_positive_prompts(self):
+ pos_prompts = self.config.generation.positive_prompt
+ if pos_prompts.endswith(".json"):
+ pos_prompts = prepare_json_dataset(pos_prompts)
+ else:
+ raise NotImplementedError
+ assert isinstance(pos_prompts, ListConfig)
+
+ return pos_prompts
+
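+# TeaCache: a step-skipping cache. check() accumulates a model-specific polynomial rescaling of
+# the relative L1 change between successive modulated timestep inputs; while the accumulated
+# value stays under rel_l1_thresh the DiT forward is skipped and update() re-applies the
+# previously stored residual instead of recomputing it.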
+class TeaCache:
+ def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
+ self.num_inference_steps = num_inference_steps
+ self.step = 0
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = None
+ self.rel_l1_thresh = rel_l1_thresh
+ self.previous_residual = None
+ self.previous_hidden_states = None
+
+ self.coefficients_dict = {
+ "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
+ "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
+ "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
+ "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
+ }
+ if model_id not in self.coefficients_dict:
+ supported_model_ids = ", ".join([i for i in self.coefficients_dict])
+ raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
+ self.coefficients = self.coefficients_dict[model_id]
+
+ def check(self, dit, x, t_mod):
+ modulated_inp = t_mod.clone()
+ if self.step == 0 or self.step == self.num_inference_steps - 1:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ else:
+ coefficients = self.coefficients
+ rescale_func = np.poly1d(coefficients)
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+ should_calc = False
+ else:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = modulated_inp
+ self.step += 1
+ if self.step == self.num_inference_steps:
+ self.step = 0
+ if should_calc:
+ self.previous_hidden_states = x.clone()
+ return not should_calc
+
+ def store(self, hidden_states):
+ if self.previous_hidden_states is None:
+ return
+ self.previous_residual = hidden_states - self.previous_hidden_states
+ self.previous_hidden_states = None
+
+ def update(self, hidden_states):
+ hidden_states = hidden_states + self.previous_residual
+ return hidden_states
\ No newline at end of file
diff --git a/humo/generate_1_7B.py b/humo/generate_1_7B.py
new file mode 100644
index 0000000000000000000000000000000000000000..debd2cad292734d94ba30deb94725c2be7df302e
--- /dev/null
+++ b/humo/generate_1_7B.py
@@ -0,0 +1,622 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Inference code adapted from [SeedVR]
+# https://github.com/ByteDance-Seed/SeedVR/blob/main/projects/inference_seedvr2_7b.py
+
+import math
+import os
+import gc
+import random
+import sys
+import mediapy
+import torch
+import torch.distributed as dist
+from omegaconf import DictConfig, ListConfig, OmegaConf
+from einops import rearrange
+from PIL import Image, ImageOps
+from torchvision.transforms import ToTensor
+from tqdm import tqdm
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.fsdp import (
+ BackwardPrefetch,
+ FullyShardedDataParallel,
+ MixedPrecision,
+ ShardingStrategy,
+)
+from common.distributed import (
+ get_device,
+ get_global_rank,
+ get_local_rank,
+ meta_param_init_fn,
+ meta_non_persistent_buffer_init_fn,
+ init_torch,
+)
+from common.distributed.advanced import (
+ init_unified_parallel,
+ get_unified_parallel_world_size,
+ get_sequence_parallel_rank,
+ init_model_shard_cpu_group,
+)
+from common.logger import get_logger
+from common.config import create_object
+from common.distributed import get_device, get_global_rank
+from torchvision.transforms import Compose, Normalize, ToTensor
+from humo.models.wan_modules.t5 import T5EncoderModel
+from humo.models.wan_modules.vae import WanVAE
+from humo.models.utils.utils import tensor_to_video, prepare_json_dataset
+from contextlib import contextmanager
+import torch.cuda.amp as amp
+from humo.models.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+from humo.utils.audio_processor_whisper import AudioProcessor
+from humo.utils.wav2vec import linear_interpolation_fps
+
+
+image_transform = Compose([
+ ToTensor(),
+ Normalize(mean=0.5, std=0.5),
+])
+
+SIZE_CONFIGS = {
+ '720*1280': (720, 1280),
+ '1280*720': (1280, 720),
+ '480*832': (480, 832),
+ '832*480': (832, 480),
+ '1024*1024': (1024, 1024),
+}
+
+def clever_format(nums, format="%.2f"):
+ from typing import Iterable
+ if not isinstance(nums, Iterable):
+ nums = [nums]
+ clever_nums = []
+ for num in nums:
+ if num > 1e12:
+ clever_nums.append(format % (num / 1e12) + "T")
+ elif num > 1e9:
+ clever_nums.append(format % (num / 1e9) + "G")
+ elif num > 1e6:
+ clever_nums.append(format % (num / 1e6) + "M")
+ elif num > 1e3:
+ clever_nums.append(format % (num / 1e3) + "K")
+ else:
+ clever_nums.append(format % num + "B")
+
+ clever_nums = clever_nums[0] if len(clever_nums) == 1 else (*clever_nums,)
+
+ return clever_nums
+
+
+class Generator():
+ def __init__(self, config: DictConfig):
+ self.config = config.copy()
+ OmegaConf.set_readonly(self.config, True)
+ self.logger = get_logger(self.__class__.__name__)
+ self.configure_models()
+
+ # init_torch(cudnn_benchmark=False)
+
+ def get_fsdp_sharding_config(self, sharding_strategy, device_mesh_config):
+ device_mesh = None
+ fsdp_strategy = ShardingStrategy[sharding_strategy]
+ if (
+ fsdp_strategy in [ShardingStrategy._HYBRID_SHARD_ZERO2, ShardingStrategy.HYBRID_SHARD]
+ and device_mesh_config is not None
+ ):
+ device_mesh = init_device_mesh("cuda", tuple(device_mesh_config))
+ return device_mesh, fsdp_strategy
+
+ def configure_models(self):
+ self.configure_dit_model(device="cpu")
+ self.configure_vae_model()
+ if self.config.generation.get('extract_audio_feat', False):
+ self.configure_wav2vec(device="cpu")
+ self.configure_text_model(device="cpu")
+
+ # Initialize fsdp.
+ self.configure_dit_fsdp_model()
+ self.configure_text_fsdp_model()
+
+ def configure_dit_model(self, device=get_device()):
+
+ init_unified_parallel(self.config.dit.sp_size)
+ self.sp_size = get_unified_parallel_world_size()
+
+ # Create dit model.
+ init_device = "meta"
+ with torch.device(init_device):
+ self.dit = create_object(self.config.dit.model)
+ self.logger.info(f"Load DiT model on {init_device}.")
+ self.dit.eval().requires_grad_(False)
+
+ # Load dit checkpoint.
+ path = self.config.dit.checkpoint_dir
+ if path.endswith(".pth"):
+ state = torch.load(path, map_location=device, mmap=True)
+ missing_keys, unexpected_keys = self.dit.load_state_dict(state, strict=False, assign=True)
+ self.logger.info(
+ f"dit loaded from {path}. "
+ f"Missing keys: {len(missing_keys)}, "
+ f"Unexpected keys: {len(unexpected_keys)}"
+ )
+ else:
+ from safetensors.torch import load_file
+ import json
+ def load_custom_sharded_weights(model_dir, base_name, device=device):
+ index_path = f"{model_dir}/{base_name}.safetensors.index.json"
+ with open(index_path, "r") as f:
+ index = json.load(f)
+ weight_map = index["weight_map"]
+ shard_files = set(weight_map.values())
+ state_dict = {}
+ for shard_file in shard_files:
+ shard_path = f"{model_dir}/{shard_file}"
+ shard_state = load_file(shard_path)
+ shard_state = {k: v.to(device) for k, v in shard_state.items()}
+ state_dict.update(shard_state)
+ return state_dict
+ state = load_custom_sharded_weights(path, 'humo', device)
+ self.dit.load_state_dict(state, strict=False, assign=True)
+
+ self.dit = meta_non_persistent_buffer_init_fn(self.dit)
+ if device in [get_device(), "cuda"]:
+ self.dit.to(get_device())
+
+ # Print model size.
+ params = sum(p.numel() for p in self.dit.parameters())
+ self.logger.info(
+ f"[RANK:{get_global_rank()}] DiT Parameters: {clever_format(params, '%.3f')}"
+ )
+
+ def configure_vae_model(self, device=get_device()):
+ self.vae_stride = self.config.vae.vae_stride
+ self.vae = WanVAE(
+ vae_pth=self.config.vae.checkpoint,
+ device=device)
+
+ if self.config.generation.height == 480:
+ self.zero_vae = torch.load(self.config.dit.zero_vae_path)
+ elif self.config.generation.height == 720:
+ self.zero_vae = torch.load(self.config.dit.zero_vae_720p_path)
+ else:
+ raise ValueError(f"Unsupported height {self.config.generation.height} for zero-vae.")
+
+ def configure_wav2vec(self, device=get_device()):
+ audio_separator_model_file = self.config.audio.vocal_separator
+ wav2vec_model_path = self.config.audio.wav2vec_model
+
+ self.audio_processor = AudioProcessor(
+ 16000,
+ 25,
+ wav2vec_model_path,
+ "all",
+ audio_separator_model_file,
+ None, # not separate
+ os.path.join(self.config.generation.output.dir, "vocals"),
+ device=device,
+ )
+
+ def configure_text_model(self, device=get_device()):
+ self.text_encoder = T5EncoderModel(
+ text_len=self.config.dit.model.text_len,
+ dtype=torch.bfloat16,
+ device=device,
+ checkpoint_path=self.config.text.t5_checkpoint,
+ tokenizer_path=self.config.text.t5_tokenizer,
+ )
+
+
+ def configure_dit_fsdp_model(self):
+ self.dit.to(get_device())
+
+ return
+
+
+ def configure_text_fsdp_model(self):
+ self.text_encoder.to(get_device())
+
+ return
+
+
+ def load_image_latent_ref_id(self, path: str, size, device):
+ # Load size.
+ h, w = size[1], size[0]
+
+ # Load image.
+ if len(path) > 1 and not isinstance(path, str):
+ ref_vae_latents = []
+ for image_path in path:
+ with Image.open(image_path) as img:
+ img = img.convert("RGB")
+
+ # Calculate the required size to keep aspect ratio and fill the rest with padding.
+ img_ratio = img.width / img.height
+ target_ratio = w / h
+
+ if img_ratio > target_ratio: # Image is wider than target
+ new_width = w
+ new_height = int(new_width / img_ratio)
+ else: # Image is taller than target
+ new_height = h
+ new_width = int(new_height * img_ratio)
+
+ # img = img.resize((new_width, new_height), Image.ANTIALIAS)
+ img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+ # Create a new image with the target size and place the resized image in the center
+ delta_w = w - img.size[0]
+ delta_h = h - img.size[1]
+ padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
+ new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
+
+ # Transform to tensor and normalize.
+ transform = Compose(
+ [
+ ToTensor(),
+ Normalize(0.5, 0.5),
+ ]
+ )
+ new_img = transform(new_img)
+ # img_vae_latent = self.vae_encode([new_img.unsqueeze(1)])[0]
+ img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device)
+ ref_vae_latents.append(img_vae_latent[0])
+
+ return [torch.cat(ref_vae_latents, dim=1)]
+ else:
+ if not isinstance(path, str):
+ path = path[0]
+ with Image.open(path) as img:
+ img = img.convert("RGB")
+
+ # Calculate the required size to keep aspect ratio and fill the rest with padding.
+ img_ratio = img.width / img.height
+ target_ratio = w / h
+
+ if img_ratio > target_ratio: # Image is wider than target
+ new_width = w
+ new_height = int(new_width / img_ratio)
+ else: # Image is taller than target
+ new_height = h
+ new_width = int(new_height * img_ratio)
+
+ # img = img.resize((new_width, new_height), Image.ANTIALIAS)
+ img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+ # Create a new image with the target size and place the resized image in the center
+ delta_w = w - img.size[0]
+ delta_h = h - img.size[1]
+ padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
+ new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
+
+ # Transform to tensor and normalize.
+ transform = Compose(
+ [
+ ToTensor(),
+ Normalize(0.5, 0.5),
+ ]
+ )
+ new_img = transform(new_img)
+ img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device)
+
+ # Vae encode.
+ return img_vae_latent
+
+ def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2):
+ zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
+ zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) # device=audio_emb.device
+ iter_ = 1 + (frame_num - 1) // 4
+ audio_emb_wind = []
+ for lt_i in range(iter_):
+ if lt_i == 0:
+ st = frame0_idx + lt_i - 2
+ ed = frame0_idx + lt_i + 3
+ wind_feat = torch.stack([
+ audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
+ for i in range(st, ed)
+ ], dim=0)
+ wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0)
+ else:
+ st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
+ ed = frame0_idx + 1 + 4 * lt_i + audio_shift
+ wind_feat = torch.stack([
+ audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
+ for i in range(st, ed)
+ ], dim=0)
+ audio_emb_wind.append(wind_feat)
+ audio_emb_wind = torch.stack(audio_emb_wind, dim=0)
+
+ return audio_emb_wind, ed - audio_shift
+
+ def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"):
+ if wav_enc_type == "wav2vec":
+ feat_merge = audio_emb
+ elif wav_enc_type == "whisper":
+ feat0 = linear_interpolation_fps(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25)
+ feat1 = linear_interpolation_fps(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25)
+ feat2 = linear_interpolation_fps(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25)
+ feat3 = linear_interpolation_fps(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25)
+ feat4 = linear_interpolation_fps(audio_emb[:, :, 32], 50, 25)
+ feat_merge = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0]
+ else:
+ raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}")
+
+ return feat_merge
+
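+ # In the 1.7B path the reference latent is spliced directly into the trailing frames of the
+ # model input, and guidance blends three nested deltas:
+ # scale_i * (TIA - TA) + scale_a * (TA - T) + scale_t * (T - neg) + neg.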
+ def forward_tia(self, latents, latents_ref, latents_ref_neg, timestep, arg_t, arg_ta, arg_null):
+ neg = self.dit(
+ [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_null
+ )[0]
+
+ pos_t = self.dit(
+ [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_t
+ )[0]
+ pos_ta = self.dit(
+ [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_ta
+ )[0]
+ pos_tia = self.dit(
+ [torch.cat([latent[:,:-latent_ref.shape[1]], latent_ref], dim=1) for latent, latent_ref in zip(latents, latents_ref)], t=timestep, **arg_ta
+ )[0]
+
+ noise_pred = self.config.generation.scale_i * (pos_tia - pos_ta) + \
+ self.config.generation.scale_a * (pos_ta - pos_t) + \
+ self.config.generation.scale_t * (pos_t - neg) + \
+ neg
+
+ return noise_pred
+
+ def forward_ta(self, latents, latents_ref_neg, timestep, arg_t, arg_ta, arg_null):
+ neg = self.dit(
+ [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_null
+ )[0]
+
+ pos_t = self.dit(
+ [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_t
+ )[0]
+ pos_ta = self.dit(
+ [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_ta
+ )[0]
+
+ noise_pred = self.config.generation.scale_a * (pos_ta - pos_t) + \
+ self.config.generation.scale_t * (pos_t - neg) + \
+ neg
+
+ return noise_pred
+
+
+ @torch.no_grad()
+ def inference(self,
+ input_prompt,
+ img_path,
+ audio_path,
+ size=(1280, 720),
+ frame_num=81,
+ shift=5.0,
+ sample_solver='unipc',
+ sampling_steps=50,
+ n_prompt="",
+ seed=-1,
+ offload_model=True,
+ device = get_device(),
+ ):
+
+ self.vae.model.to(device=device)
+ if img_path is not None:
+ latents_ref = self.load_image_latent_ref_id(img_path, size, device)
+ else:
+ latents_ref = [torch.zeros(16, 1, size[1]//8, size[0]//8).to(device)]
+
+ self.vae.model.to(device="cpu")
+ latents_ref_neg = [torch.zeros_like(latent_ref) for latent_ref in latents_ref]
+
+ # audio
+ if audio_path is not None:
+ if self.config.generation.extract_audio_feat:
+ self.audio_processor.whisper.to(device=device)
+ audio_emb, audio_length = self.audio_processor.preprocess(audio_path)
+ self.audio_processor.whisper.to(device='cpu')
+ else:
+ audio_emb_path = audio_path.replace(".wav", ".pt")
+ audio_emb = torch.load(audio_emb_path).to(device=device)
+ audio_emb = self.audio_emb_enc(audio_emb, wav_enc_type="whisper")
+ self.logger.info("使用预先提取好的音频特征: %s", audio_emb_path)
+ else:
+ audio_emb = torch.zeros(frame_num, 5, 1280).to(device)
+
+ frame_num = frame_num if frame_num != -1 else audio_length
+ frame_num = 4 * ((frame_num - 1) // 4) + 1
+ audio_emb, _ = self.get_audio_emb_window(audio_emb, frame_num, frame0_idx=0)
+ zero_audio_pad = torch.zeros(latents_ref[0].shape[1], *audio_emb.shape[1:]).to(audio_emb.device)
+ audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0)
+ audio_emb = [audio_emb.to(device)]
+ audio_emb_neg = [torch.zeros_like(audio_emb[0])]
+
+ # preprocess
+ self.patch_size = self.config.dit.model.patch_size
+ F = frame_num
+ target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + latents_ref[0].shape[1],
+ size[1] // self.vae_stride[1],
+ size[0] // self.vae_stride[2])
+
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) /
+ (self.patch_size[1] * self.patch_size[2]) *
+ target_shape[1] / self.sp_size) * self.sp_size
+
+ if n_prompt == "":
+ n_prompt = self.config.generation.sample_neg_prompt
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
+ seed_g = torch.Generator(device=device)
+ seed_g.manual_seed(seed)
+
+ self.text_encoder.model.to(device)
+ context = self.text_encoder([input_prompt], device)
+ context_null = self.text_encoder([n_prompt], device)
+ self.text_encoder.model.cpu()
+
+ noise = [
+ torch.randn(
+ target_shape[0],
+ target_shape[1], # - latents_ref[0].shape[1],
+ target_shape[2],
+ target_shape[3],
+ dtype=torch.float32,
+ device=device,
+ generator=seed_g)
+ ]
+
+ @contextmanager
+ def noop_no_sync():
+ yield
+
+ no_sync = getattr(self.dit, 'no_sync', noop_no_sync)
+ # step_change = self.config.generation.step_change # 980
+
+ # evaluation mode
+ with amp.autocast(dtype=torch.bfloat16), torch.no_grad(), no_sync():
+
+ if sample_solver == 'unipc':
+ sample_scheduler = FlowUniPCMultistepScheduler(
+ num_train_timesteps=1000,
+ shift=1,
+ use_dynamic_shifting=False)
+ sample_scheduler.set_timesteps(
+ sampling_steps, device=device, shift=shift)
+ timesteps = sample_scheduler.timesteps
+
+ # sample videos
+ latents = noise
+
+ # The reference image latent is passed by hand in the inputs below, not through the arg dicts.
+ arg_ta = {'context': context, 'seq_len': seq_len, 'audio': audio_emb}
+ arg_t = {'context': context, 'seq_len': seq_len, 'audio': audio_emb_neg}
+ arg_null = {'context': context_null, 'seq_len': seq_len, 'audio': audio_emb_neg}
+
+ torch.cuda.empty_cache()
+ self.dit.to(device=get_device())
+ for _, t in enumerate(tqdm(timesteps)):
+ timestep = [t]
+ timestep = torch.stack(timestep)
+
+ if self.config.generation.mode == "TIA":
+ noise_pred = self.forward_tia(latents, latents_ref, latents_ref_neg, timestep, arg_t, arg_ta, arg_null)
+ elif self.config.generation.mode == "TA":
+ noise_pred = self.forward_ta(latents, latents_ref_neg, timestep, arg_t, arg_ta, arg_null)
+ else:
+ raise ValueError(f"Unsupported generation mode: {self.config.generation.mode}")
+
+ temp_x0 = sample_scheduler.step(
+ noise_pred.unsqueeze(0),
+ t,
+ latents[0].unsqueeze(0),
+ return_dict=False,
+ generator=seed_g)[0]
+ latents = [temp_x0.squeeze(0)]
+
+ del timestep
+ torch.cuda.empty_cache()
+
+ x0 = latents
+ x0 = [x0_[:,:-latents_ref[0].shape[1]] for x0_ in x0]
+
+ # Offload the DiT to CPU before VAE decode to free GPU memory (done unconditionally here,
+ # despite the offload_model flag in the signature).
+ self.dit.cpu()
+ torch.cuda.empty_cache()
+ # if get_local_rank() == 0:
+ self.vae.model.to(device=device)
+ videos = self.vae.decode(x0)
+ self.vae.model.to(device="cpu")
+
+ del noise, latents, noise_pred
+ del audio_emb, audio_emb_neg, latents_ref, latents_ref_neg, context, context_null
+ del x0, temp_x0
+ del sample_scheduler
+ torch.cuda.empty_cache()
+ gc.collect()
+ torch.cuda.synchronize()
+ if dist.is_initialized():
+ dist.barrier()
+
+ return videos[0] # if get_local_rank() == 0 else None
+
+
+ def inference_loop(self, prompt, ref_img_path, audio_path, output_dir, filename, width = 832, height = 480, steps=50, frames = 97, seed = 0):
+ print(f'ref_img_path:{ref_img_path}')
+
+ video = self.inference(
+ prompt,
+ ref_img_path,
+ audio_path,
+ size=SIZE_CONFIGS[f"{width}*{height}"],
+ frame_num=frames,
+ shift=self.config.diffusion.timesteps.sampling.shift,
+ sample_solver='unipc',
+ sampling_steps=steps,
+ seed=seed,
+ offload_model=False,
+ )
+
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
+ # Save samples.
+ if get_sequence_parallel_rank() == 0:
+ pathname = self.save_sample(
+ sample=video,
+ audio_path=audio_path,
+ output_dir=output_dir,
+ filename=filename,
+ )
+ self.logger.info(f"Finished {filename}, saved to {pathname}.")
+
+ del video, prompt
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
+
+ def save_sample(self, *, sample: torch.Tensor, audio_path: str, output_dir: str, filename: str):
+ gen_config = self.config.generation
+ # Prepare file path.
+ extension = ".mp4" if sample.ndim == 4 else ".png"
+ filename += extension
+ pathname = os.path.join(output_dir, filename)
+ # Convert sample.
+ sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).to("cpu", torch.uint8)
+ sample = rearrange(sample, "c t h w -> t h w c")
+ # Save file.
+ if sample.ndim == 4:
+ if audio_path is not None:
+ tensor_to_video(
+ sample.numpy(),
+ pathname,
+ audio_path,
+ fps=gen_config.fps)
+ else:
+ mediapy.write_video(
+ path=pathname,
+ images=sample.numpy(),
+ fps=gen_config.fps,
+ )
+ else:
+ raise ValueError(f"Expected a 4D video tensor to save, got shape {tuple(sample.shape)}.")
+ return pathname
+
+
+ def prepare_positive_prompts(self):
+ pos_prompts = self.config.generation.positive_prompt
+ if pos_prompts.endswith(".json"):
+ pos_prompts = prepare_json_dataset(pos_prompts)
+ else:
+ raise NotImplementedError
+ assert isinstance(pos_prompts, ListConfig)
+
+ return pos_prompts
\ No newline at end of file
diff --git a/humo/models/audio/audio_proj.py b/humo/models/audio/audio_proj.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d4771dad1e648fb063c30a18258d258a4739dc4
--- /dev/null
+++ b/humo/models/audio/audio_proj.py
@@ -0,0 +1,87 @@
+import torch
+from einops import rearrange
+from torch import nn
+
+class WanRMSNorm(nn.Module):
+
+ def __init__(self, dim, eps=1e-5):
+ super().__init__()
+ self.dim = dim
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C]
+ """
+ return self._norm(x.float()).type_as(x) * self.weight
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+
+
+class DummyAdapterLayer(nn.Module):
+ def __init__(self, layer):
+ super().__init__()
+ self.layer = layer
+
+ def forward(self, *args, **kwargs):
+ return self.layer(*args, **kwargs)
+
+
+class AudioProjModel(nn.Module):
+ def __init__(
+ self,
+ seq_len=5,
+ blocks=13, # number of feature blocks per audio window
+ channels=768, # channels per block
+ intermediate_dim=512,
+ output_dim=1536,
+ context_tokens=16,
+ ):
+ super().__init__()
+
+ self.seq_len = seq_len
+ self.blocks = blocks
+ self.channels = channels
+ self.input_dim = seq_len * blocks * channels # input_dim is the product of seq_len, blocks, and channels.
+ self.intermediate_dim = intermediate_dim
+ self.context_tokens = context_tokens
+ self.output_dim = output_dim
+
+ # define multiple linear layers
+ self.audio_proj_glob_1 = DummyAdapterLayer(nn.Linear(self.input_dim, intermediate_dim))
+ self.audio_proj_glob_2 = DummyAdapterLayer(nn.Linear(intermediate_dim, intermediate_dim))
+ self.audio_proj_glob_3 = DummyAdapterLayer(nn.Linear(intermediate_dim, context_tokens * output_dim))
+
+ self.audio_proj_glob_norm = DummyAdapterLayer(nn.LayerNorm(output_dim))
+
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ # Initialize linear layers (Xavier-uniform weights, zero bias):
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ def forward(self, audio_embeds):
+ video_length = audio_embeds.shape[1]
+ audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
+ batch_size, window_size, blocks, channels = audio_embeds.shape
+ audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
+
+ audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds))
+ audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds))
+
+ context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim)
+
+ context_tokens = self.audio_proj_glob_norm(context_tokens)
+ context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
+
+ return context_tokens
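+
+# Minimal usage sketch (shapes are illustrative; the defaults presumably correspond to
+# Whisper-style features with 13 blocks of 768 channels over a 5-frame window):
+# proj = AudioProjModel(seq_len=5, blocks=13, channels=768, output_dim=1536)
+# feats = torch.randn(1, 24, 5, 13, 768) # [batch, frames, window, blocks, channels]
+# tokens = proj(feats) # -> [1, 24, 16, 1536]: 16 context tokens per frame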
\ No newline at end of file
diff --git a/humo/models/distributed/__init__.py b/humo/models/distributed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/humo/models/distributed/dit_ulysses_sequence_parallel.py b/humo/models/distributed/dit_ulysses_sequence_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..3da02c08b488bd5130e9879ed024b64ae09287d5
--- /dev/null
+++ b/humo/models/distributed/dit_ulysses_sequence_parallel.py
@@ -0,0 +1,270 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from einops import rearrange
+from common.distributed import get_device
+
+from common.distributed.advanced import (
+ get_unified_parallel_world_size,
+ get_unified_parallel_group,
+ pad_tensor,
+ Slice,
+ gather_outputs,
+ gather_seq_scatter_heads_qkv,
+ gather_seq_scatter_double_head,
+ gather_heads_scatter_seq,
+ unpad_tensor
+)
+from humo.models.wan_modules.attention import flash_attention
+from humo.models.wan_modules.model_humo import rope_apply, sinusoidal_embedding_1d
+
+
+def ulysses_dit_forward(
+ self,
+ x,
+ t,
+ context,
+ seq_len,
+ audio=None,
+ y=None
+):
+ """
+ x: A list of videos each with shape [C, T, H, W].
+ t: [B].
+ context: A list of text embeddings each with shape [L, C].
+ """
+ if self.model_type == 'i2v':
+ # assert clip_fea is not None and y is not None
+ assert y is not None
+ # params
+ device = self.patch_embedding.weight.device
+ if self.freqs.device != device:
+ self.freqs = self.freqs.to(device)
+
+ if y is not None:
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+
+ # embeddings
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
+ grid_sizes = torch.stack(
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
+ x = [u.flatten(2).transpose(1, 2) for u in x]
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long, device=device)
+
+ assert seq_lens.max() <= seq_len
+ x = torch.cat([
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
+ for u in x
+ ])
+
+ # time embeddings
+ with torch.amp.autocast('cuda', dtype=torch.float32):
+ e = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, t).float()).float()
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim)).float()
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
+
+ # context
+ context_lens = None
+ context = self.text_embedding(
+ torch.stack([
+ torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
+ for u in context
+ ]))
+
+ if self.insert_audio:
+ audio = [self.audio_proj(au.unsqueeze(0)).permute(0, 3, 1, 2) for au in audio]
+
+ audio_seq_len = torch.tensor(max([au.shape[2] for au in audio]) * audio[0].shape[3], device=get_device())
+ audio = [au.flatten(2).transpose(1, 2) for au in audio] # [1, t*32, 1536]
+ audio_seq_lens = torch.tensor([au.size(1) for au in audio], dtype=torch.long, device=device)
+ audio = torch.cat([
+ torch.cat([au, au.new_zeros(1, audio_seq_len - au.size(1), au.size(2))],
+ dim=1) for au in audio
+ ])
+ else:
+ audio = None
+ audio_seq_len = None
+ audio_seq_lens = None
+
+ # ulysses support
+ sp_world = get_unified_parallel_world_size()
+ group = get_unified_parallel_group()
+ if seq_len % sp_world:
+ padding_size = sp_world - (seq_len % sp_world)
+ x = pad_tensor(x, dim=1, padding_size=padding_size)
+
+ if self.insert_audio:
+ audio_padding_size = sp_world - (audio_seq_len % sp_world)
+ audio = pad_tensor(audio, dim=1, padding_size=audio_padding_size)
+
+ x = Slice.apply(group, x, 1, True)
+
+ if self.insert_audio:
+ audio = Slice.apply(group, audio, 1, True)
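+
+ # From here each rank holds a contiguous 1/sp_world slice of the (padded) video and audio
+ # token sequences; inside the attention blocks the split moves from sequence to heads
+ # (Ulysses all-to-all), and the full sequence is gathered back after the output head below.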
+
+ # arguments
+ kwargs = dict(
+ e=e0,
+ seq_lens=seq_lens,
+ grid_sizes=grid_sizes,
+ freqs=self.freqs,
+ context=context,
+ context_lens=context_lens,
+ audio=audio,
+ audio_seq_len=audio_seq_len)
+
+ for block in self.blocks:
+ x = block(x, **kwargs)
+
+ # head
+ x = self.head(x, e)
+
+ # ulysses support
+ x = gather_outputs(x, gather_dim=1, padding_dim=1, unpad_dim_size=seq_len, scale_grad=True)
+
+ # unpatchify
+ x = self.unpatchify(x, grid_sizes)
+ return [u.float() for u in x]
+
+
+def ulysses_attn_forward(
+ self,
+ x,
+ seq_lens,
+ grid_sizes,
+ freqs,
+ dtype=torch.bfloat16
+):
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+ seq_len = seq_lens.max()
+ half_dtypes = (torch.float16, torch.bfloat16)
+
+ def half(x):
+ return x if x.dtype in half_dtypes else x.to(dtype)
+
+ # query, key, value function
+ def qkv_fn(x):
+ q = self.norm_q(self.q(x))
+ k = self.norm_k(self.k(x))
+ v = self.v(x)
+ return q, k, v
+
+ q, k, v = qkv_fn(x)
+
+ # ulysses support
+ sp_size = get_unified_parallel_world_size()
+ if n % sp_size:
+ pad_size = sp_size - (n % sp_size)
+ pad_size = pad_size * d
+ pad_inner_dim = n * d + pad_size
+ q = pad_tensor(q, dim=2, padding_size=pad_size)
+ k = pad_tensor(k, dim=2, padding_size=pad_size)
+ v = pad_tensor(v, dim=2, padding_size=pad_size)
+ else:
+ pad_inner_dim = n * d
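+ # Example of the head padding above (illustrative numbers): with n=40 heads and sp_size=16,
+ # 40 % 16 = 8, so 8 * d extra channels are padded on, giving pad_n=48 heads and
+ # pad_split_n = 48 / 16 = 3 heads per rank after the all-to-all below.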
+
+ qkv = torch.cat([q, k, v], dim=2)
+ qkv = gather_seq_scatter_heads_qkv(qkv, seq_dim=1, unpadded_dim_size=seq_len)
+ q, k, v = qkv.split(pad_inner_dim // sp_size, dim=2)
+
+ pad_n = pad_inner_dim // d
+ pad_split_n = pad_n // sp_size
+ q = q.view(b, seq_len, pad_split_n, d)
+ k = k.view(b, seq_len, pad_split_n, d)
+ v = v.view(b, seq_len, pad_split_n, d)
+
+ q = rope_apply(q, grid_sizes, freqs)
+ k = rope_apply(k, grid_sizes, freqs)
+
+ x = flash_attention(
+ q=half(q),
+ k=half(k),
+ v=half(v),
+ k_lens=seq_lens,
+ window_size=self.window_size
+ )
+
+ # ulysses support
+ x = x.flatten(2)
+ x = gather_heads_scatter_seq(x, head_dim=2, seq_dim=1)
+ if n % sp_size:
+ x = unpad_tensor(x, dim=2, unpad_dim_size=seq_len)
+
+ x = self.o(x)
+ return x
+
+
+def ulysses_audio_cross_attn_forward(
+ self,
+ x,
+ audio,
+ seq_lens,
+ grid_sizes,
+ freqs,
+ audio_seq_len,
+ dtype=torch.bfloat16
+):
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+ seq_len = seq_lens.max()
+
+ q = self.norm_q(self.q(x))
+ k = self.norm_k(self.k(audio))
+ v = self.v(audio)
+
+ # ulysses support
+ sp_size = get_unified_parallel_world_size()
+ if n % sp_size:
+ pad_size = sp_size - (n % sp_size)
+ pad_size = pad_size * d
+ pad_inner_dim = n * d + pad_size
+ q = pad_tensor(q, dim=2, padding_size=pad_size)
+ k = pad_tensor(k, dim=2, padding_size=pad_size)
+ v = pad_tensor(v, dim=2, padding_size=pad_size)
+ else:
+ pad_inner_dim = n * d
+
+ qq = torch.cat([q, q], dim=2)
+ kv = torch.cat([k, v], dim=2)
+ qq = gather_seq_scatter_double_head(qq, seq_dim=1, unpadded_dim_size=seq_len)
+ kv = gather_seq_scatter_double_head(kv, seq_dim=1, unpadded_dim_size=audio_seq_len)
+ q, _ = qq.split(pad_inner_dim // sp_size, dim=2)
+ k, v = kv.split(pad_inner_dim // sp_size, dim=2)
+
+ pad_n = pad_inner_dim // d
+ pad_split_n = pad_n // sp_size
+ q = q.view(b, seq_len, pad_split_n, d)
+ k = k.view(b, audio_seq_len, pad_split_n, d)
+ v = v.view(b, audio_seq_len, pad_split_n, d)
+
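+ # hlen_wlen is the number of spatial tokens per latent frame; the reshapes below split the
+ # sequence per frame so each frame's video tokens attend to a small block of audio tokens
+ # (groups of 16). The assert pins the two spatial layouts this path was written for,
+ # presumably 832x480 -> 30*52=1560 and 1280x720 -> 45*80=3600 under VAE stride 8 and patch size 2.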
+ hlen_wlen = int(grid_sizes[0][1] * grid_sizes[0][2])
+ assert hlen_wlen == 1560 or hlen_wlen == 3600
+ q = q.reshape(-1, hlen_wlen, pad_split_n, d)
+ k = k.reshape(-1, 16, pad_split_n, d)
+ v = v.reshape(-1, 16, pad_split_n, d)
+
+ x = flash_attention(
+ q=q,
+ k=k,
+ v=v,
+ k_lens=None,
+ )
+ x = x.view(b, -1, pad_split_n, d)
+
+ # ulysses support
+ x = x.flatten(2)
+ x = gather_heads_scatter_seq(x, head_dim=2, seq_dim=1)
+ if n % sp_size:
+ x = unpad_tensor(x, dim=2, unpad_dim_size=seq_len)
+
+ x = self.o(x)
+ return x
diff --git a/humo/models/distributed/fsdp.py b/humo/models/distributed/fsdp.py
new file mode 100644
index 0000000000000000000000000000000000000000..51c81e7700577e07ef1a2ebf7f26cc02ad269cd7
--- /dev/null
+++ b/humo/models/distributed/fsdp.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+
+import torch
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
+from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
+
+
+def shard_model(
+ model,
+ device_id,
+ param_dtype=torch.bfloat16,
+ reduce_dtype=torch.float32,
+ buffer_dtype=torch.float32,
+ process_group=None,
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
+ sync_module_states=True,
+):
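+ # Each transformer block in `model.blocks` becomes its own FSDP unit via the lambda wrap
+ # policy; by default parameters are sharded in bf16 while gradient reduction and buffers
+ # stay in fp32.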
+ model = FSDP(
+ module=model,
+ process_group=process_group,
+ sharding_strategy=sharding_strategy,
+ auto_wrap_policy=partial(
+ lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
+ mixed_precision=MixedPrecision(
+ param_dtype=param_dtype,
+ reduce_dtype=reduce_dtype,
+ buffer_dtype=buffer_dtype),
+ device_id=device_id,
+ sync_module_states=sync_module_states)
+ return model
diff --git a/humo/models/text/encoder.py b/humo/models/text/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b444bb7e778936be7f874947942780bcf3cfcfc
--- /dev/null
+++ b/humo/models/text/encoder.py
@@ -0,0 +1,173 @@
+import os
+from dataclasses import dataclass
+from typing import List, Optional, Union
+import torch
+from omegaconf import DictConfig, OmegaConf
+from torch import nn
+from transformers import (
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ CLIPTextModel,
+ CLIPTokenizerFast,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+from transformers.tokenization_utils_base import BatchEncoding
+
+from common.fs import download_and_extract
+from common.logger import get_logger
+
+logger = get_logger(__name__)
+
+MODEL_TYPES = {
+ "clip": (CLIPTokenizerFast, CLIPTextModel),
+ "t5": (T5TokenizerFast, T5EncoderModel),
+ "llm14b": (AutoTokenizer, AutoModelForCausalLM),
+}
+
+
+@dataclass
+class TextEncoderOutput:
+ embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]]
+ masks: Union[torch.BoolTensor, List[torch.BoolTensor]]
+ pooled: Optional[Union[torch.FloatTensor, List[torch.FloatTensor]]]
+
+
+class TextEncoder(nn.Module):
+ def __init__(self, config: DictConfig):
+ super().__init__()
+ self.config = config
+ self.tokenizers = []
+ self.models = nn.ModuleList([])
+
+ # Disable tokenizer parallelism since we already use distributed training.
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
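+ # Expected layout of `config.models` (a sketch inferred from the loop below; names and
+ # values are illustrative, not a definitive schema):
+ # models:
+ # - type: t5 # one of: clip | t5 | llm14b
+ # path: <checkpoint path or archive>
+ # max_length: 512
+ # mask: true # optional, read in forward()
+ # layer: last # optional: last | penultimate | penultimate_nonorm (non-LLM encoders)
+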
+ for model in config.models:
+ tokenizer_cls, model_cls = MODEL_TYPES[model.type]
+ path = download_and_extract(model.path)
+ max_length = model.max_length
+
+ if model.type == "llm14b":
+ tokenizer = tokenizer_cls.from_pretrained(
+ path,
+ model_max_length=max_length,
+ use_fast=False,
+ trust_remote_code=True,
+ padding_side="right",
+ truncation_side="right",
+ add_eod_token=True,
+ )
+ tokenizer.add_special_tokens({"pad_token": "<|endoftext|>"})
+ model = model_cls.from_pretrained(path, trust_remote_code=True, bf16=True)
+ else:
+ tokenizer = tokenizer_cls.from_pretrained(path, model_max_length=max_length)
+ model = model_cls.from_pretrained(path, torch_dtype=torch.bfloat16)
+ self.tokenizers.append(tokenizer)
+ self.models.append(model)
+
+ def forward(self, text: Union[str, List[str]]) -> TextEncoderOutput:
+ embeddings, masks, pooled = [], [], []
+
+ for encoder_config, tokenizer, model in zip(
+ self.config.models, self.tokenizers, self.models
+ ):
+ if encoder_config.type == "llm14b":
+ use_mask = encoder_config.get("mask", True)
+ tokens = tokenizer(
+ text,
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ ).to(model.device)
+ token_ids = tokens["input_ids"]
+ attention_mask = tokens["attention_mask"]
+ num_tokens = attention_mask.sum(dim=1)
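+ # Write an explicit end-of-text/pad token right after the last real token (clamped to the
+ # final position when the prompt fills the whole window) and mark it as attended, so the
+ # LLM sees a terminator even under right-padding.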
+ range_ids = torch.arange(len(token_ids), device=token_ids.device, dtype=torch.long)
+ token_ids[range_ids, num_tokens.clamp(max=token_ids.size(1) - 1)] = (
+ tokenizer.pad_token_id
+ )
+ attention_mask[range_ids, num_tokens.clamp(max=token_ids.size(1) - 1)] = 1
+ tokens = BatchEncoding({"input_ids": token_ids, "attention_mask": attention_mask})
+ output = model.transformer(
+ input_ids=tokens.input_ids,
+ attention_mask=attention_mask if use_mask else None,
+ output_hidden_states=False,
+ use_cache=False,
+ )
+ emb = output.last_hidden_state # batch_size, num_tokens, feat_dim
+ # emb *= tokens.attention_mask.unsqueeze(-1)
+
+ embeddings.append(emb)
+ masks.append(
+ tokens.attention_mask.bool() if use_mask else tokens.attention_mask > -1
+ )
+
+ else:
+ # Tokenizer
+ tokens = tokenizer(
+ text=text,
+ truncation=True,
+ padding="max_length",
+ return_tensors="pt",
+ )
+
+ # Encoder
+ use_mask = encoder_config.get("mask", True)
+ input_ids = tokens.input_ids.to(model.device)
+ attention_mask = tokens.attention_mask.to(model.device)
+ output = model(
+ input_ids=input_ids,
+ attention_mask=attention_mask if use_mask else None,
+ output_hidden_states=True,
+ )
+
+ # Save embeddings from the defined layer.
+ layer = encoder_config.get("layer", "last")
+ if layer == "last":
+ embeddings.append(output.last_hidden_state)
+ elif layer == "penultimate":
+ embeddings.append(model.text_model.final_layer_norm(output.hidden_states[-2]))
+ elif layer == "penultimate_nonorm":
+ embeddings.append(output.hidden_states[-2])
+ else:
+ raise NotImplementedError(f"Unknown layer type: {layer}.")
+
+ # Save masks
+ masks.append(attention_mask.bool() if use_mask else attention_mask > -1)
+
+ # Save pooled output if available.
+ if hasattr(output, "pooler_output"):
+ pooled.append(output.pooler_output)
+
+ output_config = self.config.get("output") or OmegaConf.create()
+ embedding_output_type = output_config.get("embedding_and_mask", "undefined")
+ pooled_output_type = output_config.get("pooled", "undefined")
+
+ # Select or merge embeddings and mask if needed.
+ if embedding_output_type == "undefined" and len(self.models) == 1:
+ embeddings = embeddings[0]
+ masks = masks[0]
+ elif embedding_output_type == "channel_concat":
+ embeddings = torch.cat(embeddings, dim=-1)
+ masks = sum(masks).bool()
+ elif embedding_output_type == "last":
+ embeddings = embeddings[-1]
+ masks = masks[-1]
+ else:
+ raise NotImplementedError(f"output.embedding_and_mask: {embedding_output_type}")
+
+ # Select or merge pooled output if needed.
+ if pooled_output_type == "undefined":
+ pooled = None
+ elif pooled_output_type == "channel_concat":
+ pooled = torch.cat(pooled, dim=-1)
+ elif pooled_output_type == "first":
+ pooled = pooled[0]
+ elif pooled_output_type == "last":
+ pooled = pooled[-1]
+ else:
+ raise NotImplementedError(f"output.pooled: {pooled_output_type}")
+
+ # Return final results.
+ return TextEncoderOutput(embeddings, masks, pooled)
diff --git a/humo/models/utils/fm_solvers.py b/humo/models/utils/fm_solvers.py
new file mode 100644
index 0000000000000000000000000000000000000000..1271b678bd0b1287adebbc344ef5fbfa952a9558
--- /dev/null
+++ b/humo/models/utils/fm_solvers.py
@@ -0,0 +1,857 @@
+# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+# Convert dpm solver for flow matching
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+
+import inspect
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
+ SchedulerMixin,
+ SchedulerOutput)
+from diffusers.utils import deprecate, is_scipy_available
+from diffusers.utils.torch_utils import randn_tensor
+
+if is_scipy_available():
+ pass
+
+
+def get_sampling_sigmas(sampling_steps, shift):
+ sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
+ sigma = (shift * sigma / (1 + (shift - 1) * sigma))
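+ # The shift maps sigma -> shift * sigma / (1 + (shift - 1) * sigma): it is monotonic on [0, 1]
+ # and shift > 1 biases the schedule toward higher noise levels. Illustrative numbers: with
+ # shift=3, sigma=0.5 maps to 1.5 / (1 + 1.0) = 0.75.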
+
+ return sigma
+
+
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps=None,
+ device=None,
+ timesteps=None,
+ sigmas=None,
+ **kwargs,
+):
+ if timesteps is not None and sigmas is not None:
+ raise ValueError(
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
+ )
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(
+ inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(
+ inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model. This determines the resolution of the diffusion process.
+ solver_order (`int`, defaults to 2):
+ The DPMSolver order which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided
+ sampling, and `solver_order=3` for unconditional sampling. This affects the number of model outputs stored
+ and used in multistep updates.
+ prediction_type (`str`, defaults to "flow_prediction"):
+ Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
+ the flow of the diffusion process.
+ shift (`float`, *optional*, defaults to 1.0):
+ A factor used to adjust the sigmas in the noise schedule. It modifies the step sizes during the sampling
+ process.
+ use_dynamic_shifting (`bool`, defaults to `False`):
+ Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is
+ applied on the fly.
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This method adjusts the predicted sample to prevent
+ saturation and improve photorealism.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
+ `algorithm_type="dpmsolver++"`.
+ algorithm_type (`str`, defaults to `dpmsolver++`):
+ Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
+ `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
+ paper, and the `dpmsolver++` type implements the algorithms in the
+ [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
+ `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
+ solver_type (`str`, defaults to `midpoint`):
+ Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
+ sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
+ lower_order_final (`bool`, defaults to `True`):
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
+ euler_at_final (`bool`, defaults to `False`):
+ Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
+ richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
+ steps, but sometimes may result in blurring.
+ final_sigmas_type (`str`, *optional*, defaults to "zero"):
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+ lambda_min_clipped (`float`, defaults to `-inf`):
+ Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
+ cosine (`squaredcos_cap_v2`) noise schedule.
+ variance_type (`str`, *optional*):
+ Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
+ contains the predicted Gaussian variance.
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ solver_order: int = 2,
+ prediction_type: str = "flow_prediction",
+ shift: Optional[float] = 1.0,
+ use_dynamic_shifting=False,
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ sample_max_value: float = 1.0,
+ algorithm_type: str = "dpmsolver++",
+ solver_type: str = "midpoint",
+ lower_order_final: bool = True,
+ euler_at_final: bool = False,
+ final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
+ lambda_min_clipped: float = -float("inf"),
+ variance_type: Optional[str] = None,
+ invert_sigmas: bool = False,
+ ):
+ if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
+ deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
+ deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0",
+ deprecation_message)
+
+ # settings for DPM-Solver
+ if algorithm_type not in [
+ "dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"
+ ]:
+ if algorithm_type == "deis":
+ self.register_to_config(algorithm_type="dpmsolver++")
+ else:
+ raise NotImplementedError(
+ f"{algorithm_type} is not implemented for {self.__class__}")
+
+ if solver_type not in ["midpoint", "heun"]:
+ if solver_type in ["logrho", "bh1", "bh2"]:
+ self.register_to_config(solver_type="midpoint")
+ else:
+ raise NotImplementedError(
+ f"{solver_type} is not implemented for {self.__class__}")
+
+ if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"
+ ] and final_sigmas_type == "zero":
+ raise ValueError(
+ f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
+ )
+
+ # setable values
+ self.num_inference_steps = None
+ alphas = np.linspace(1, 1 / num_train_timesteps,
+ num_train_timesteps)[::-1].copy()
+ sigmas = 1.0 - alphas
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
+
+ if not use_dynamic_shifting:
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
+ sigmas = shift * sigmas / (1 +
+ (shift - 1) * sigmas) # pyright: ignore
+
+ self.sigmas = sigmas
+ self.timesteps = sigmas * num_train_timesteps
+
+ self.model_outputs = [None] * solver_order
+ self.lower_order_nums = 0
+ self._step_index = None
+ self._begin_index = None
+
+ # self.sigmas = self.sigmas.to(
+ # "cpu") # to avoid too much CPU/GPU communication
+ self.sigma_min = self.sigmas[-1].item()
+ self.sigma_max = self.sigmas[0].item()
+
+ @property
+ def step_index(self):
+ """
+ The index counter for current timestep. It will increase 1 after each scheduler step.
+ """
+ return self._step_index
+
+ @property
+ def begin_index(self):
+ """
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+ """
+ return self._begin_index
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+ Args:
+ begin_index (`int`):
+ The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
+ def set_timesteps(
+ self,
+ num_inference_steps: Union[int, None] = None,
+ device: Union[str, torch.device] = None,
+ sigmas: Optional[List[float]] = None,
+ mu: Optional[Union[float, None]] = None,
+ shift: Optional[Union[float, None]] = None,
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ """
+
+ if self.config.use_dynamic_shifting and mu is None:
+ raise ValueError(
+ " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
+ )
+
+ if sigmas is None:
+ sigmas = np.linspace(self.sigma_max, self.sigma_min,
+ num_inference_steps +
+ 1).copy()[:-1] # pyright: ignore
+
+ if self.config.use_dynamic_shifting:
+ sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
+ else:
+ if shift is None:
+ shift = self.config.shift
+ sigmas = shift * sigmas / (1 +
+ (shift - 1) * sigmas) # pyright: ignore
+
+ if self.config.final_sigmas_type == "sigma_min":
+ sigma_last = ((1 - self.alphas_cumprod[0]) /
+ self.alphas_cumprod[0])**0.5
+ elif self.config.final_sigmas_type == "zero":
+ sigma_last = 0
+ else:
+ raise ValueError(
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+ )
+
+ timesteps = sigmas * self.config.num_train_timesteps
+ sigmas = np.concatenate([sigmas, [sigma_last]
+ ]).astype(np.float32) # pyright: ignore
+
+ self.sigmas = torch.from_numpy(sigmas)
+ self.timesteps = torch.from_numpy(timesteps).to(
+ device=device, dtype=torch.int64)
+
+ self.num_inference_steps = len(timesteps)
+
+ self.model_outputs = [
+ None,
+ ] * self.config.solver_order
+ self.lower_order_nums = 0
+
+ self._step_index = None
+ self._begin_index = None
+ # self.sigmas = self.sigmas.to(
+ # "cpu") # to avoid too much CPU/GPU communication
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, *remaining_dims = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float(
+ ) # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(
+ abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+ s = s.unsqueeze(
+ 1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(
+ sample, -s, s
+ ) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
+ sample = sample.to(dtype)
+
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
+ def _sigma_to_t(self, sigma):
+ return sigma * self.config.num_train_timesteps
+
+ def _sigma_to_alpha_sigma_t(self, sigma):
+ return 1 - sigma, sigma
+
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
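+ # Note on time_shift above: with sigma=1 it reduces to the static shift used elsewhere with
+ # shift = exp(mu), i.e. exp(mu) * t / (1 + (exp(mu) - 1) * t). Callers typically derive mu
+ # from the token count when `use_dynamic_shifting` is enabled (an assumption about the calling
+ # pipeline, not enforced here).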
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
+ def convert_model_output(
+ self,
+ model_output: torch.Tensor,
+ *args,
+ sample: torch.Tensor = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
+ designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
+ integral of the data prediction model.
+
+ The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
+ prediction and data prediction models.
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ Returns:
+ `torch.Tensor`:
+ The converted model output.
+ """
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
+ if sample is None:
+ if len(args) > 1:
+ sample = args[1]
+ else:
+ raise ValueError(
+ "missing `sample` as a required keyward argument")
+ if timestep is not None:
+ deprecate(
+ "timesteps",
+ "1.0.0",
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ # DPM-Solver++ needs to solve an integral of the data prediction model.
+ if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
+ if self.config.prediction_type == "flow_prediction":
+ sigma_t = self.sigmas[self.step_index]
+ x0_pred = sample - sigma_t * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
+ " `v_prediction`, or `flow_prediction` for the FlowDPMSolverMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ x0_pred = self._threshold_sample(x0_pred)
+
+ return x0_pred
+
+ # DPM-Solver needs to solve an integral of the noise prediction model.
+ elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
+ if self.config.prediction_type == "flow_prediction":
+ sigma_t = self.sigmas[self.step_index]
+ epsilon = sample - (1 - sigma_t) * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
+ " `v_prediction` or `flow_prediction` for the FlowDPMSolverMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ sigma_t = self.sigmas[self.step_index]
+ x0_pred = sample - sigma_t * model_output
+ x0_pred = self._threshold_sample(x0_pred)
+ epsilon = model_output + x0_pred
+
+ return epsilon
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update
+ def dpm_solver_first_order_update(
+ self,
+ model_output: torch.Tensor,
+ *args,
+ sample: torch.Tensor = None,
+ noise: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ One step for the first-order DPMSolver (equivalent to DDIM).
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ Returns:
+ `torch.Tensor`:
+ The sample tensor at the previous timestep.
+ """
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
+ "prev_timestep", None)
+ if sample is None:
+ if len(args) > 2:
+ sample = args[2]
+ else:
+ raise ValueError(
+ " missing `sample` as a required keyward argument")
+ if timestep is not None:
+ deprecate(
+ "timesteps",
+ "1.0.0",
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ if prev_timestep is not None:
+ deprecate(
+ "prev_timestep",
+ "1.0.0",
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[
+ self.step_index] # pyright: ignore
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
+
+ h = lambda_t - lambda_s
+ if self.config.algorithm_type == "dpmsolver++":
+ x_t = (sigma_t /
+ sigma_s) * sample - (alpha_t *
+ (torch.exp(-h) - 1.0)) * model_output
+ elif self.config.algorithm_type == "dpmsolver":
+ x_t = (alpha_t /
+ alpha_s) * sample - (sigma_t *
+ (torch.exp(h) - 1.0)) * model_output
+ elif self.config.algorithm_type == "sde-dpmsolver++":
+ assert noise is not None
+ x_t = ((sigma_t / sigma_s * torch.exp(-h)) * sample +
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output +
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
+ elif self.config.algorithm_type == "sde-dpmsolver":
+ assert noise is not None
+ x_t = ((alpha_t / alpha_s) * sample - 2.0 *
+ (sigma_t * (torch.exp(h) - 1.0)) * model_output +
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
+ return x_t # pyright: ignore
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
+ def multistep_dpm_solver_second_order_update(
+ self,
+ model_output_list: List[torch.Tensor],
+ *args,
+ sample: torch.Tensor = None,
+ noise: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ One step for the second-order multistep DPMSolver.
+ Args:
+ model_output_list (`List[torch.Tensor]`):
+ The direct outputs from learned diffusion model at current and latter timesteps.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ Returns:
+ `torch.Tensor`:
+ The sample tensor at the previous timestep.
+ """
+ timestep_list = args[0] if len(args) > 0 else kwargs.pop(
+ "timestep_list", None)
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
+ "prev_timestep", None)
+ if sample is None:
+ if len(args) > 2:
+ sample = args[2]
+ else:
+ raise ValueError(
+ " missing `sample` as a required keyward argument")
+ if timestep_list is not None:
+ deprecate(
+ "timestep_list",
+ "1.0.0",
+ "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ if prev_timestep is not None:
+ deprecate(
+ "prev_timestep",
+ "1.0.0",
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ sigma_t, sigma_s0, sigma_s1 = (
+ self.sigmas[self.step_index + 1], # pyright: ignore
+ self.sigmas[self.step_index],
+ self.sigmas[self.step_index - 1], # pyright: ignore
+ )
+
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
+
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
+
+ m0, m1 = model_output_list[-1], model_output_list[-2]
+
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
+ r0 = h_0 / h
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
+ if self.config.algorithm_type == "dpmsolver++":
+ # See https://arxiv.org/abs/2211.01095 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = ((sigma_t / sigma_s0) * sample -
+ (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 *
+ (alpha_t * (torch.exp(-h) - 1.0)) * D1)
+ elif self.config.solver_type == "heun":
+ x_t = ((sigma_t / sigma_s0) * sample -
+ (alpha_t * (torch.exp(-h) - 1.0)) * D0 +
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1)
+ elif self.config.algorithm_type == "dpmsolver":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = ((alpha_t / alpha_s0) * sample -
+ (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 *
+ (sigma_t * (torch.exp(h) - 1.0)) * D1)
+ elif self.config.solver_type == "heun":
+ x_t = ((alpha_t / alpha_s0) * sample -
+ (sigma_t * (torch.exp(h) - 1.0)) * D0 -
+ (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1)
+ elif self.config.algorithm_type == "sde-dpmsolver++":
+ assert noise is not None
+ if self.config.solver_type == "midpoint":
+ x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample +
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 *
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 +
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
+ elif self.config.solver_type == "heun":
+ x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample +
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 +
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) /
+ (-2.0 * h) + 1.0)) * D1 +
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
+ elif self.config.algorithm_type == "sde-dpmsolver":
+ assert noise is not None
+ if self.config.solver_type == "midpoint":
+ x_t = ((alpha_t / alpha_s0) * sample - 2.0 *
+ (sigma_t * (torch.exp(h) - 1.0)) * D0 -
+ (sigma_t * (torch.exp(h) - 1.0)) * D1 +
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
+ elif self.config.solver_type == "heun":
+ x_t = ((alpha_t / alpha_s0) * sample - 2.0 *
+ (sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 *
+ (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 +
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
+ return x_t # pyright: ignore
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
+ def multistep_dpm_solver_third_order_update(
+ self,
+ model_output_list: List[torch.Tensor],
+ *args,
+ sample: torch.Tensor = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ One step for the third-order multistep DPMSolver.
+ Args:
+ model_output_list (`List[torch.Tensor]`):
+ The direct outputs from learned diffusion model at current and latter timesteps.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by diffusion process.
+ Returns:
+ `torch.Tensor`:
+ The sample tensor at the previous timestep.
+ """
+
+ timestep_list = args[0] if len(args) > 0 else kwargs.pop(
+ "timestep_list", None)
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
+ "prev_timestep", None)
+ if sample is None:
+ if len(args) > 2:
+ sample = args[2]
+ else:
+ raise ValueError(
+ " missing`sample` as a required keyward argument")
+ if timestep_list is not None:
+ deprecate(
+ "timestep_list",
+ "1.0.0",
+ "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ if prev_timestep is not None:
+ deprecate(
+ "prev_timestep",
+ "1.0.0",
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
+ self.sigmas[self.step_index + 1], # pyright: ignore
+ self.sigmas[self.step_index],
+ self.sigmas[self.step_index - 1], # pyright: ignore
+ self.sigmas[self.step_index - 2], # pyright: ignore
+ )
+
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
+ alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
+
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
+ lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
+
+ m0, m1, m2 = model_output_list[-1], model_output_list[
+ -2], model_output_list[-3]
+
+ h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
+ r0, r1 = h_0 / h, h_1 / h
+ D0 = m0
+ D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
+ if self.config.algorithm_type == "dpmsolver++":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ x_t = ((sigma_t / sigma_s0) * sample -
+ (alpha_t * (torch.exp(-h) - 1.0)) * D0 +
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 -
+ (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2)
+ elif self.config.algorithm_type == "dpmsolver":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ x_t = ((alpha_t / alpha_s0) * sample - (sigma_t *
+ (torch.exp(h) - 1.0)) * D0 -
+ (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 -
+ (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2)
+ return x_t # pyright: ignore
+
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ indices = (schedule_timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ pos = 1 if len(indices) > 1 else 0
+
+ return indices[pos].item()
+
+ def _init_step_index(self, timestep):
+ """
+ Initialize the step_index counter for the scheduler.
+ """
+
+ if self.begin_index is None:
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ self._step_index = self.index_for_timestep(timestep)
+ else:
+ self._step_index = self._begin_index
+
+ # Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
+ def step(
+ self,
+ model_output: torch.Tensor,
+ timestep: Union[int, torch.Tensor],
+ sample: torch.Tensor,
+ generator=None,
+ variance_noise: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
+ the multistep DPMSolver.
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from learned diffusion model.
+ timestep (`int`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ variance_noise (`torch.Tensor`):
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
+ itself. Useful for methods such as [`LEdits++`].
+ return_dict (`bool`):
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ # Improve numerical stability for small number of steps
+ lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
+ self.config.euler_at_final or
+ (self.config.lower_order_final and len(self.timesteps) < 15) or
+ self.config.final_sigmas_type == "zero")
+ lower_order_second = ((self.step_index == len(self.timesteps) - 2) and
+ self.config.lower_order_final and
+ len(self.timesteps) < 15)
+
+ model_output = self.convert_model_output(model_output, sample=sample)
+ for i in range(self.config.solver_order - 1):
+ self.model_outputs[i] = self.model_outputs[i + 1]
+ self.model_outputs[-1] = model_output
+
+ # Upcast to avoid precision issues when computing prev_sample
+ sample = sample.to(torch.float32)
+ if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"
+ ] and variance_noise is None:
+ noise = randn_tensor(
+ model_output.shape,
+ generator=generator,
+ device=model_output.device,
+ dtype=torch.float32)
+ elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
+ noise = variance_noise.to(
+ device=model_output.device,
+ dtype=torch.float32) # pyright: ignore
+ else:
+ noise = None
+
+ if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
+ prev_sample = self.dpm_solver_first_order_update(
+ model_output, sample=sample, noise=noise)
+ elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
+ prev_sample = self.multistep_dpm_solver_second_order_update(
+ self.model_outputs, sample=sample, noise=noise)
+ else:
+ prev_sample = self.multistep_dpm_solver_third_order_update(
+ self.model_outputs, sample=sample)
+
+ if self.lower_order_nums < self.config.solver_order:
+ self.lower_order_nums += 1
+
+ # Cast sample back to expected dtype
+ prev_sample = prev_sample.to(model_output.dtype)
+
+ # upon completion increase step index by one
+ self._step_index += 1 # pyright: ignore
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input
+ def scale_model_input(self, sample: torch.Tensor, *args,
+ **kwargs) -> torch.Tensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ Returns:
+ `torch.Tensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.Tensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(
+ device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(
+ timesteps):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(
+ original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(
+ original_samples.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
+ if self.begin_index is None:
+ step_indices = [
+ self.index_for_timestep(t, schedule_timesteps)
+ for t in timesteps
+ ]
+ elif self.step_index is not None:
+ # add_noise is called after first denoising step (for inpainting)
+ step_indices = [self.step_index] * timesteps.shape[0]
+ else:
+ # add noise is called before first denoising step to create initial latent(img2img)
+ step_indices = [self.begin_index] * timesteps.shape[0]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/humo/models/utils/fm_solvers_unipc.py b/humo/models/utils/fm_solvers_unipc.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ed93733d74df502df2a9bf1dc6509c2193368b7
--- /dev/null
+++ b/humo/models/utils/fm_solvers_unipc.py
@@ -0,0 +1,800 @@
+# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
+# Convert unipc for flow matching
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
+ SchedulerMixin,
+ SchedulerOutput)
+from diffusers.utils import deprecate, is_scipy_available
+
+if is_scipy_available():
+ import scipy.stats
+
+
+class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ solver_order (`int`, default `2`):
+ The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
+ due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
+ unconditional sampling.
+ prediction_type (`str`, defaults to "flow_prediction"):
+ Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
+ the flow of the diffusion process.
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+ as Stable Diffusion.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
+ predict_x0 (`bool`, defaults to `True`):
+ Whether to use the updating algorithm on the predicted x0.
+ solver_type (`str`, default `bh2`):
+ Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
+ otherwise.
+ lower_order_final (`bool`, default `True`):
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
+ disable_corrector (`list`, default `[]`):
+ Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
+ and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
+ usually disabled during the first few steps.
+        solver_p (`SchedulerMixin`, default `None`):
+            Any other scheduler; if specified, the algorithm becomes `solver_p + UniC`.
+        shift (`float`, *optional*, defaults to 1.0):
+            Static shift applied to the sigma schedule as `shift * sigma / (1 + (shift - 1) * sigma)`.
+        use_dynamic_shifting (`bool`, *optional*, defaults to `False`):
+            Whether to compute the shift on the fly (via the `mu` argument of `set_timesteps`) instead of using the
+            static `shift` value.
+ timestep_spacing (`str`, defaults to `"linspace"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps, as required by some model families.
+ final_sigmas_type (`str`, defaults to `"zero"`):
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ solver_order: int = 2,
+ prediction_type: str = "flow_prediction",
+ shift: Optional[float] = 1.0,
+ use_dynamic_shifting=False,
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ sample_max_value: float = 1.0,
+ predict_x0: bool = True,
+ solver_type: str = "bh2",
+ lower_order_final: bool = True,
+ disable_corrector: List[int] = [],
+ solver_p: SchedulerMixin = None,
+ timestep_spacing: str = "linspace",
+ steps_offset: int = 0,
+ final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
+ ):
+
+ if solver_type not in ["bh1", "bh2"]:
+ if solver_type in ["midpoint", "heun", "logrho"]:
+ self.register_to_config(solver_type="bh2")
+ else:
+ raise NotImplementedError(
+ f"{solver_type} is not implemented for {self.__class__}")
+
+ self.predict_x0 = predict_x0
+ # setable values
+ self.num_inference_steps = None
+ alphas = np.linspace(1, 1 / num_train_timesteps,
+ num_train_timesteps)[::-1].copy()
+ sigmas = 1.0 - alphas
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
+
+ if not use_dynamic_shifting:
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
+ sigmas = shift * sigmas / (1 +
+ (shift - 1) * sigmas) # pyright: ignore
+
+ self.sigmas = sigmas
+ self.timesteps = sigmas * num_train_timesteps
+
+ self.model_outputs = [None] * solver_order
+ self.timestep_list = [None] * solver_order
+ self.lower_order_nums = 0
+ self.disable_corrector = disable_corrector
+ self.solver_p = solver_p
+ self.last_sample = None
+ self._step_index = None
+ self._begin_index = None
+
+ self.sigmas = self.sigmas.to(
+ "cpu") # to avoid too much CPU/GPU communication
+ self.sigma_min = self.sigmas[-1].item()
+ self.sigma_max = self.sigmas[0].item()
+
+ @property
+ def step_index(self):
+ """
+ The index counter for current timestep. It will increase 1 after each scheduler step.
+ """
+ return self._step_index
+
+ @property
+ def begin_index(self):
+ """
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+ """
+ return self._begin_index
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+ Args:
+ begin_index (`int`):
+ The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
+ def set_timesteps(
+ self,
+ num_inference_steps: Union[int, None] = None,
+ device: Union[str, torch.device] = None,
+ sigmas: Optional[List[float]] = None,
+ mu: Optional[Union[float, None]] = None,
+ shift: Optional[Union[float, None]] = None,
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+ Args:
+            num_inference_steps (`int`):
+                The number of denoising steps used when generating samples with the scheduler.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+
+ if self.config.use_dynamic_shifting and mu is None:
+ raise ValueError(
+ " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
+ )
+
+ if sigmas is None:
+ sigmas = np.linspace(self.sigma_max, self.sigma_min,
+ num_inference_steps +
+ 1).copy()[:-1] # pyright: ignore
+
+ if self.config.use_dynamic_shifting:
+ sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
+ else:
+ if shift is None:
+ shift = self.config.shift
+ sigmas = shift * sigmas / (1 +
+ (shift - 1) * sigmas) # pyright: ignore
+
+ if self.config.final_sigmas_type == "sigma_min":
+ sigma_last = ((1 - self.alphas_cumprod[0]) /
+ self.alphas_cumprod[0])**0.5
+ elif self.config.final_sigmas_type == "zero":
+ sigma_last = 0
+ else:
+ raise ValueError(
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+ )
+
+ timesteps = sigmas * self.config.num_train_timesteps
+ sigmas = np.concatenate([sigmas, [sigma_last]
+ ]).astype(np.float32) # pyright: ignore
+
+ self.sigmas = torch.from_numpy(sigmas)
+ self.timesteps = torch.from_numpy(timesteps).to(
+ device=device, dtype=torch.int64)
+
+ self.num_inference_steps = len(timesteps)
+
+ self.model_outputs = [
+ None,
+ ] * self.config.solver_order
+ self.lower_order_nums = 0
+ self.last_sample = None
+ if self.solver_p:
+ self.solver_p.set_timesteps(self.num_inference_steps, device=device)
+
+ # add an index counter for schedulers that allow duplicated timesteps
+ self._step_index = None
+ self._begin_index = None
+ self.sigmas = self.sigmas.to(
+ "cpu") # to avoid too much CPU/GPU communication
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, *remaining_dims = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float(
+ ) # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(
+ abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+ s = s.unsqueeze(
+ 1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(
+ sample, -s, s
+ ) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
+ sample = sample.to(dtype)
+
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
+ def _sigma_to_t(self, sigma):
+ return sigma * self.config.num_train_timesteps
+
+ def _sigma_to_alpha_sigma_t(self, sigma):
+ return 1 - sigma, sigma
+
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
+
+ def convert_model_output(
+ self,
+ model_output: torch.Tensor,
+ *args,
+ sample: torch.Tensor = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ r"""
+ Convert the model output to the corresponding type the UniPC algorithm needs.
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model.
+ timestep (`int`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+
+ Returns:
+ `torch.Tensor`:
+ The converted model output.
+ """
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
+ if sample is None:
+ if len(args) > 1:
+ sample = args[1]
+ else:
+ raise ValueError(
+                    "missing `sample` as a required keyword argument")
+ if timestep is not None:
+ deprecate(
+ "timesteps",
+ "1.0.0",
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ sigma = self.sigmas[self.step_index]
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+
+ if self.predict_x0:
+ if self.config.prediction_type == "flow_prediction":
+ sigma_t = self.sigmas[self.step_index]
+ x0_pred = sample - sigma_t * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
+ " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ x0_pred = self._threshold_sample(x0_pred)
+
+ return x0_pred
+ else:
+ if self.config.prediction_type == "flow_prediction":
+ sigma_t = self.sigmas[self.step_index]
+ epsilon = sample - (1 - sigma_t) * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
+ " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ sigma_t = self.sigmas[self.step_index]
+ x0_pred = sample - sigma_t * model_output
+ x0_pred = self._threshold_sample(x0_pred)
+ epsilon = model_output + x0_pred
+
+ return epsilon
+
+ def multistep_uni_p_bh_update(
+ self,
+ model_output: torch.Tensor,
+ *args,
+ sample: torch.Tensor = None,
+ order: int = None, # pyright: ignore
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+        One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified.
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model at the current timestep.
+ prev_timestep (`int`):
+ The previous discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ order (`int`):
+ The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
+
+ Returns:
+ `torch.Tensor`:
+ The sample tensor at the previous timestep.
+ """
+ prev_timestep = args[0] if len(args) > 0 else kwargs.pop(
+ "prev_timestep", None)
+ if sample is None:
+ if len(args) > 1:
+ sample = args[1]
+ else:
+ raise ValueError(
+                    "missing `sample` as a required keyword argument")
+ if order is None:
+ if len(args) > 2:
+ order = args[2]
+ else:
+ raise ValueError(
+                    "missing `order` as a required keyword argument")
+ if prev_timestep is not None:
+ deprecate(
+ "prev_timestep",
+ "1.0.0",
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+ model_output_list = self.model_outputs
+
+ s0 = self.timestep_list[-1]
+ m0 = model_output_list[-1]
+ x = sample
+
+ if self.solver_p:
+ x_t = self.solver_p.step(model_output, s0, x).prev_sample
+ return x_t
+
+ sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[
+ self.step_index] # pyright: ignore
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
+
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
+
+ h = lambda_t - lambda_s0
+ device = sample.device
+
+ rks = []
+ D1s = []
+ for i in range(1, order):
+ si = self.step_index - i # pyright: ignore
+ mi = model_output_list[-(i + 1)]
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
+ rk = (lambda_si - lambda_s0) / h
+ rks.append(rk)
+ D1s.append((mi - m0) / rk) # pyright: ignore
+
+ rks.append(1.0)
+ rks = torch.tensor(rks, device=device)
+
+ R = []
+ b = []
+
+ hh = -h if self.predict_x0 else h
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
+ h_phi_k = h_phi_1 / hh - 1
+
+ factorial_i = 1
+
+ if self.config.solver_type == "bh1":
+ B_h = hh
+ elif self.config.solver_type == "bh2":
+ B_h = torch.expm1(hh)
+ else:
+ raise NotImplementedError()
+
+ for i in range(1, order + 1):
+ R.append(torch.pow(rks, i - 1))
+ b.append(h_phi_k * factorial_i / B_h)
+ factorial_i *= i + 1
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+ R = torch.stack(R)
+ b = torch.tensor(b, device=device)
+
+ if len(D1s) > 0:
+ D1s = torch.stack(D1s, dim=1) # (B, K)
+ # for order 2, we use a simplified version
+ if order == 2:
+ rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
+ else:
+ rhos_p = torch.linalg.solve(R[:-1, :-1],
+ b[:-1]).to(device).to(x.dtype)
+ else:
+ D1s = None
+
+ if self.predict_x0:
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
+ if D1s is not None:
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
+ D1s) # pyright: ignore
+ else:
+ pred_res = 0
+ x_t = x_t_ - alpha_t * B_h * pred_res
+ else:
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
+ if D1s is not None:
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
+ D1s) # pyright: ignore
+ else:
+ pred_res = 0
+ x_t = x_t_ - sigma_t * B_h * pred_res
+
+ x_t = x_t.to(x.dtype)
+ return x_t
+
+ def multistep_uni_c_bh_update(
+ self,
+ this_model_output: torch.Tensor,
+ *args,
+ last_sample: torch.Tensor = None,
+ this_sample: torch.Tensor = None,
+ order: int = None, # pyright: ignore
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ One step for the UniC (B(h) version).
+
+ Args:
+ this_model_output (`torch.Tensor`):
+ The model outputs at `x_t`.
+ this_timestep (`int`):
+ The current timestep `t`.
+ last_sample (`torch.Tensor`):
+ The generated sample before the last predictor `x_{t-1}`.
+ this_sample (`torch.Tensor`):
+ The generated sample after the last predictor `x_{t}`.
+ order (`int`):
+ The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
+
+ Returns:
+ `torch.Tensor`:
+ The corrected sample tensor at the current timestep.
+ """
+ this_timestep = args[0] if len(args) > 0 else kwargs.pop(
+ "this_timestep", None)
+ if last_sample is None:
+ if len(args) > 1:
+ last_sample = args[1]
+ else:
+ raise ValueError(
+                    "missing `last_sample` as a required keyword argument")
+ if this_sample is None:
+ if len(args) > 2:
+ this_sample = args[2]
+ else:
+ raise ValueError(
+                    "missing `this_sample` as a required keyword argument")
+ if order is None:
+ if len(args) > 3:
+ order = args[3]
+ else:
+ raise ValueError(
+                    "missing `order` as a required keyword argument")
+ if this_timestep is not None:
+ deprecate(
+ "this_timestep",
+ "1.0.0",
+ "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ model_output_list = self.model_outputs
+
+ m0 = model_output_list[-1]
+ x = last_sample
+ x_t = this_sample
+ model_t = this_model_output
+
+ sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[
+ self.step_index - 1] # pyright: ignore
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
+
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
+
+ h = lambda_t - lambda_s0
+ device = this_sample.device
+
+ rks = []
+ D1s = []
+ for i in range(1, order):
+ si = self.step_index - (i + 1) # pyright: ignore
+ mi = model_output_list[-(i + 1)]
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
+ rk = (lambda_si - lambda_s0) / h
+ rks.append(rk)
+ D1s.append((mi - m0) / rk) # pyright: ignore
+
+ rks.append(1.0)
+ rks = torch.tensor(rks, device=device)
+
+ R = []
+ b = []
+
+ hh = -h if self.predict_x0 else h
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
+ h_phi_k = h_phi_1 / hh - 1
+
+ factorial_i = 1
+
+ if self.config.solver_type == "bh1":
+ B_h = hh
+ elif self.config.solver_type == "bh2":
+ B_h = torch.expm1(hh)
+ else:
+ raise NotImplementedError()
+
+ for i in range(1, order + 1):
+ R.append(torch.pow(rks, i - 1))
+ b.append(h_phi_k * factorial_i / B_h)
+ factorial_i *= i + 1
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+ R = torch.stack(R)
+ b = torch.tensor(b, device=device)
+
+ if len(D1s) > 0:
+ D1s = torch.stack(D1s, dim=1)
+ else:
+ D1s = None
+
+ # for order 1, we use a simplified version
+ if order == 1:
+ rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
+ else:
+ rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
+
+ if self.predict_x0:
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
+ if D1s is not None:
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
+ else:
+ corr_res = 0
+ D1_t = model_t - m0
+ x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
+ else:
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
+ if D1s is not None:
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
+ else:
+ corr_res = 0
+ D1_t = model_t - m0
+ x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
+ x_t = x_t.to(x.dtype)
+ return x_t
+
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ indices = (schedule_timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ pos = 1 if len(indices) > 1 else 0
+
+ return indices[pos].item()
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
+ def _init_step_index(self, timestep):
+ """
+ Initialize the step_index counter for the scheduler.
+ """
+
+ if self.begin_index is None:
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ self._step_index = self.index_for_timestep(timestep)
+ else:
+ self._step_index = self._begin_index
+
+ def step(self,
+ model_output: torch.Tensor,
+ timestep: Union[int, torch.Tensor],
+ sample: torch.Tensor,
+ return_dict: bool = True,
+ generator=None) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
+ the multistep UniPC.
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from learned diffusion model.
+ timestep (`int`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ return_dict (`bool`):
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ use_corrector = (
+ self.step_index > 0 and
+ self.step_index - 1 not in self.disable_corrector and
+ self.last_sample is not None # pyright: ignore
+ )
+
+ model_output_convert = self.convert_model_output(
+ model_output, sample=sample)
+ if use_corrector:
+ sample = self.multistep_uni_c_bh_update(
+ this_model_output=model_output_convert,
+ last_sample=self.last_sample,
+ this_sample=sample,
+ order=self.this_order,
+ )
+
+ for i in range(self.config.solver_order - 1):
+ self.model_outputs[i] = self.model_outputs[i + 1]
+ self.timestep_list[i] = self.timestep_list[i + 1]
+
+ self.model_outputs[-1] = model_output_convert
+ self.timestep_list[-1] = timestep # pyright: ignore
+
+ if self.config.lower_order_final:
+ this_order = min(self.config.solver_order,
+ len(self.timesteps) -
+ self.step_index) # pyright: ignore
+ else:
+ this_order = self.config.solver_order
+
+ self.this_order = min(this_order,
+ self.lower_order_nums + 1) # warmup for multistep
+ assert self.this_order > 0
+
+ self.last_sample = sample
+ prev_sample = self.multistep_uni_p_bh_update(
+ model_output=model_output, # pass the original non-converted model output, in case solver-p is used
+ sample=sample,
+ order=self.this_order,
+ )
+
+ if self.lower_order_nums < self.config.solver_order:
+ self.lower_order_nums += 1
+
+ # upon completion increase step index by one
+ self._step_index += 1 # pyright: ignore
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def scale_model_input(self, sample: torch.Tensor, *args,
+ **kwargs) -> torch.Tensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+
+ Returns:
+ `torch.Tensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.Tensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(
+ device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(
+ timesteps):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(
+ original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(
+ original_samples.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
+ if self.begin_index is None:
+ step_indices = [
+ self.index_for_timestep(t, schedule_timesteps)
+ for t in timesteps
+ ]
+ elif self.step_index is not None:
+ # add_noise is called after first denoising step (for inpainting)
+ step_indices = [self.step_index] * timesteps.shape[0]
+ else:
+            # add_noise is called before the first denoising step to create the initial latent (img2img)
+ step_indices = [self.begin_index] * timesteps.shape[0]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
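+
+
+# --- Minimal usage sketch (illustrative; not part of the upstream scheduler) ---
+# The dummy all-zero "velocity" below stands in for a real flow-matching model and the latent
+# shape is arbitrary; the block only exercises set_timesteps() and the predictor-corrector step().
+if __name__ == "__main__":
+    scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=1000, solver_order=2, shift=1.0)
+    scheduler.set_timesteps(num_inference_steps=8, device="cpu")
+
+    latent = torch.randn(1, 4, 8, 8)  # hypothetical latent; the scheduler is shape-agnostic
+    for t in scheduler.timesteps:
+        velocity = torch.zeros_like(latent)  # stand-in for model(latent, t, ...)
+        latent = scheduler.step(velocity, t, latent).prev_sample
+    print("final latent shape:", tuple(latent.shape))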
diff --git a/humo/models/utils/utils.py b/humo/models/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b30b2e95c73637ae87cb73452173e0592dba34c
--- /dev/null
+++ b/humo/models/utils/utils.py
@@ -0,0 +1,58 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import argparse
+import binascii
+import os
+import os.path as osp
+import json
+from omegaconf import OmegaConf
+
+import imageio
+import torch
+import torchvision
+from moviepy.editor import AudioFileClip, VideoClip
+
+__all__ = ['tensor_to_video', 'prepare_json_dataset']
+
+
+def tensor_to_video(tensor, output_video_path, input_audio_path, fps=25):
+ """
+    Converts an array of video frames with shape [f, h, w, c] into a video file and adds an audio track from the
+    specified audio file.
+
+    Args:
+        tensor (numpy.ndarray): The frames to be converted, shaped [f, h, w, c].
+        output_video_path (str): The file path where the output video will be saved.
+        input_audio_path (str): The path to the audio file (e.g. a WAV file) whose track is added to the video.
+        fps (int): The frame rate of the output video. Default is 25 fps.
+ """
+ def make_frame(t):
+ frame_index = min(int(t * fps), tensor.shape[0] - 1)
+ return tensor[frame_index]
+
+ video_duration = tensor.shape[0] / fps
+ audio_clip = AudioFileClip(input_audio_path)
+ audio_duration = audio_clip.duration
+ final_duration = min(video_duration, audio_duration)
+ audio_clip = audio_clip.subclip(0, final_duration)
+ new_video_clip = VideoClip(make_frame, duration=final_duration)
+ new_video_clip = new_video_clip.set_audio(audio_clip)
+ new_video_clip.write_videofile(output_video_path, fps=fps, audio_codec="aac")
+
+
+def prepare_json_dataset(json_path):
+ samples = []
+ with open(json_path, "rb") as f:
+ data = json.load(f)
+ for itemname, row in data.items():
+ text = row['prompt'].strip().replace("_", " ").strip('"')
+ audio_path = row['audio_path']
+ ref_img_path = [x for x in row['img_paths']]
+
+ samples.append({
+ "text": text,
+ "ref_img": ref_img_path,
+ "audio": audio_path,
+ "itemname": itemname
+ })
+ samples = OmegaConf.create(samples)
+
+ return samples
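+
+
+# Expected JSON layout for `prepare_json_dataset`, inferred from the keys read above; the file
+# name and values below are purely illustrative:
+#
+# {
+#   "case_0001": {
+#     "prompt": "A person speaking to the camera",
+#     "audio_path": "path/to/audio.wav",
+#     "img_paths": ["path/to/reference.png"]
+#   }
+# }
+#
+# samples = prepare_json_dataset("path/to/test_cases.json")
+# print(samples[0].itemname, samples[0].text, samples[0].audio)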
\ No newline at end of file
diff --git a/humo/models/wan_modules/__init__.py b/humo/models/wan_modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c7e79b0864945199247cabc9b48138e411b3216
--- /dev/null
+++ b/humo/models/wan_modules/__init__.py
@@ -0,0 +1,16 @@
+from .attention import flash_attention
+from .model import WanModel
+from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
+from .tokenizers import HuggingfaceTokenizer
+from .vae import WanVAE
+
+__all__ = [
+ 'WanVAE',
+ 'WanModel',
+ 'T5Model',
+ 'T5Encoder',
+ 'T5Decoder',
+ 'T5EncoderModel',
+ 'HuggingfaceTokenizer',
+ 'flash_attention',
+]
diff --git a/humo/models/wan_modules/attention.py b/humo/models/wan_modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..76d26535e1d7e106dc8889e0a64549ebc2aad600
--- /dev/null
+++ b/humo/models/wan_modules/attention.py
@@ -0,0 +1,256 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import warnings
+import torch
+from typing import Optional, Tuple
+
+try:
+ import flash_attn_interface
+ FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+ FLASH_ATTN_3_AVAILABLE = False
+
+try:
+ import flash_attn
+ FLASH_ATTN_2_AVAILABLE = True
+except ModuleNotFoundError:
+ FLASH_ATTN_2_AVAILABLE = False
+
+
+__all__ = [
+ 'flash_attention',
+ 'attention',
+]
+
+
+# ---------------------------
+# Custom op + fake kernel
+# ---------------------------
+from typing import Optional, Sequence
+
+@torch.library.custom_op("wan::flash_attention", mutates_args=())
+def _wan_flash_attention_op(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ q_lens: Optional[torch.Tensor] = None,
+ k_lens: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ softmax_scale: Optional[float] = None,
+ q_scale: Optional[float] = None,
+ causal: bool = False,
+ # IMPORTANT: schema-friendly default (None), not a tuple
+ window_size: Optional[Sequence[int]] = None,
+ deterministic: bool = False,
+ dtype: torch.dtype = torch.bfloat16,
+ version: Optional[int] = None,
+) -> torch.Tensor:
+ half_dtypes = (torch.float16, torch.bfloat16)
+ assert dtype in half_dtypes
+ assert q.size(-1) <= 256
+
+ # normalize window_size to a 2-tuple for FA2 API
+ if window_size is None:
+ ws = (-1, -1)
+ else:
+ ws = tuple(window_size)
+ if len(ws) != 2:
+ raise ValueError(f"window_size must have length 2; got {window_size!r}")
+
+ b, lq, nheads = q.shape[0], q.shape[1], q.shape[2]
+ lk = k.shape[1]
+ out_dtype = q.dtype
+
+ def half(x: torch.Tensor) -> torch.Tensor:
+ return x if x.dtype in half_dtypes else x.to(dtype)
+
+ # --- preprocess (unchanged) ---
+ if q_lens is None:
+ q_flat = half(q.flatten(0, 1))
+ q_lens = torch.tensor([lq] * b, dtype=torch.int32)
+ else:
+ q_flat = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
+
+ if k_lens is None:
+ k_flat = half(k.flatten(0, 1))
+ v_flat = half(v.flatten(0, 1))
+ k_lens = torch.tensor([lk] * b, dtype=torch.int32)
+ else:
+ k_flat = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
+ v_flat = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
+
+ q_flat = q_flat.to(v_flat.dtype); k_flat = k_flat.to(v_flat.dtype)
+ if q_scale is not None:
+ q_flat = q_flat * q_scale
+
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
+ warnings.warn('Flash attention 3 is not available, use flash attention 2 instead.')
+
+ if FLASH_ATTN_3_AVAILABLE:
+ ret = flash_attn_interface.flash_attn_varlen_func(
+ q=q_flat,
+ k=k_flat,
+ v=v_flat,
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q_flat.device, non_blocking=True),
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(k_flat.device, non_blocking=True),
+ seqused_q=None,
+ seqused_k=None,
+ max_seqlen_q=lq,
+ max_seqlen_k=lk,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ deterministic=deterministic,
+ )
+ out0 = ret[0] if isinstance(ret, (tuple, list)) else ret
+ total_q = b * lq
+ if out0.dim() != 3:
+ raise RuntimeError(f"Unexpected FA3 output rank {out0.dim()} shape={tuple(out0.shape)}")
+ if out0.shape[0] == total_q:
+ out_flat = out0
+ elif out0.shape[0] == nheads and out0.shape[1] == total_q:
+ out_flat = out0.transpose(0, 1).contiguous()
+ else:
+ raise RuntimeError(f"Unexpected FA3 output shape {tuple(out0.shape)}")
+ out = out_flat.unflatten(0, (b, lq))
+
+ elif FLASH_ATTN_2_AVAILABLE:
+ out = flash_attn.flash_attn_varlen_func(
+ q=q_flat,
+ k=k_flat,
+ v=v_flat,
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q_flat.device, non_blocking=True),
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q_flat.device, non_blocking=True),
+ max_seqlen_q=lq,
+ max_seqlen_k=lk,
+ dropout_p=dropout_p,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size=ws, # <- pass 2-tuple
+ deterministic=deterministic,
+ ).unflatten(0, (b, lq))
+ else:
+ q_s = q.transpose(1, 2).to(dtype)
+ k_s = k.transpose(1, 2).to(dtype)
+ v_s = v.transpose(1, 2).to(dtype)
+ out = torch.nn.functional.scaled_dot_product_attention(
+ q_s, k_s, v_s, attn_mask=None, is_causal=causal, dropout_p=dropout_p
+ ).transpose(1, 2).contiguous()
+
+ return out.to(out_dtype)
+
+@_wan_flash_attention_op.register_fake
+def _wan_flash_attention_op_fake(
+ q,
+ k,
+ v,
+ q_lens=None,
+ k_lens=None,
+ dropout_p: float = 0.0,
+ softmax_scale=None,
+ q_scale=None,
+ causal: bool = False,
+ window_size: Optional[Sequence[int]] = None,
+ deterministic: bool = False,
+ dtype: torch.dtype = torch.bfloat16,
+ version: Optional[int] = None,
+):
+ # Match output shape: (B, Lq, Nq, Dh_v) and keep the SAME fake device as `q`
+ B, Lq, Nq, _ = q.shape
+ Dh_v = v.shape[-1]
+ return q.new_empty((B, Lq, Nq, Dh_v), dtype=q.dtype)
+
+
+
+# ---------------------------
+# Public API (unchanged signature)
+# ---------------------------
+def flash_attention(
+ q,
+ k,
+ v,
+ q_lens=None,
+ k_lens=None,
+ dropout_p=0.,
+ softmax_scale=None,
+ q_scale=None,
+ causal=False,
+ window_size=(-1, -1),
+ deterministic=False,
+ dtype=torch.bfloat16,
+ version=None,
+):
+ """
+ q: [B, Lq, Nq, C1].
+ k: [B, Lk, Nk, C1].
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
+ q_lens: [B].
+ k_lens: [B].
+ dropout_p: float. Dropout probability.
+ softmax_scale: float. The scaling of QK^T before applying softmax.
+ causal: bool. Whether to apply causal attention mask.
+        window_size: (left, right). If not (-1, -1), apply sliding window local attention.
+        deterministic: bool. If True, slightly slower and uses more memory.
+        dtype: torch.dtype. The dtype q/k/v are cast to when they are not already float16/bfloat16.
+ """
+ # Simply delegate to the custom op so Dynamo/AOT treats it as a single node;
+ # our eager kernel inside _wan_flash_attention_op keeps the original behavior.
+ return _wan_flash_attention_op(
+ q, k, v,
+ q_lens=q_lens,
+ k_lens=k_lens,
+ dropout_p=dropout_p,
+ softmax_scale=softmax_scale,
+ q_scale=q_scale,
+ causal=causal,
+ window_size=window_size,
+ deterministic=deterministic,
+ dtype=dtype,
+ version=version,
+ )
+
+
+def attention(
+ q,
+ k,
+ v,
+ q_lens=None,
+ k_lens=None,
+ dropout_p=0.,
+ softmax_scale=None,
+ q_scale=None,
+ causal=False,
+ window_size=(-1, -1),
+ deterministic=False,
+ dtype=torch.bfloat16,
+ fa_version=None,
+):
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
+ return flash_attention(
+ q=q,
+ k=k,
+ v=v,
+ q_lens=q_lens,
+ k_lens=k_lens,
+ dropout_p=dropout_p,
+ softmax_scale=softmax_scale,
+ q_scale=q_scale,
+ causal=causal,
+ window_size=window_size,
+ deterministic=deterministic,
+ dtype=dtype,
+ version=fa_version,
+ )
+ else:
+ if q_lens is not None or k_lens is not None:
+ warnings.warn(
+ 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
+ )
+ q_ = q.transpose(1, 2).to(dtype)
+ k_ = k.transpose(1, 2).to(dtype)
+ v_ = v.transpose(1, 2).to(dtype)
+ out = torch.nn.functional.scaled_dot_product_attention(
+ q_, k_, v_, attn_mask=None, is_causal=causal, dropout_p=dropout_p
+ )
+ return out.transpose(1, 2).contiguous()
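+
+
+# --- Minimal shape check (illustrative; not part of the upstream module) ---
+# Runs the public wrapper on random tensors; when neither flash-attn 2 nor 3 is installed the
+# call goes through the scaled_dot_product_attention fallback, so this also works on CPU.
+if __name__ == "__main__":
+    b, lq, lk, heads, head_dim = 2, 16, 24, 4, 32
+    q = torch.randn(b, lq, heads, head_dim)
+    k = torch.randn(b, lk, heads, head_dim)
+    v = torch.randn(b, lk, heads, head_dim)
+    out = attention(q, k, v)
+    print("attention output shape:", tuple(out.shape))  # expected: (b, lq, heads, head_dim)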
diff --git a/humo/models/wan_modules/clip.py b/humo/models/wan_modules/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..8659c4f13ce4da36110edb49a0ac71ccd5e3f841
--- /dev/null
+++ b/humo/models/wan_modules/clip.py
@@ -0,0 +1,542 @@
+# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import logging
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as T
+
+from .attention import flash_attention
+from .tokenizers import HuggingfaceTokenizer
+from .xlm_roberta import XLMRoberta
+
+__all__ = [
+ 'XLMRobertaCLIP',
+ 'clip_xlm_roberta_vit_h_14',
+ 'CLIPModel',
+]
+
+
+def pos_interpolate(pos, seq_len):
+ if pos.size(1) == seq_len:
+ return pos
+ else:
+ src_grid = int(math.sqrt(pos.size(1)))
+ tar_grid = int(math.sqrt(seq_len))
+ n = pos.size(1) - src_grid * src_grid
+ return torch.cat([
+ pos[:, :n],
+ F.interpolate(
+ pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
+ 0, 3, 1, 2),
+ size=(tar_grid, tar_grid),
+ mode='bicubic',
+ align_corners=False).flatten(2).transpose(1, 2)
+ ],
+ dim=1)
+
+
+class QuickGELU(nn.Module):
+
+ def forward(self, x):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class LayerNorm(nn.LayerNorm):
+
+ def forward(self, x):
+ return super().forward(x.float()).type_as(x)
+
+
+class SelfAttention(nn.Module):
+
+ def __init__(self,
+ dim,
+ num_heads,
+ causal=False,
+ attn_dropout=0.0,
+ proj_dropout=0.0):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.causal = causal
+ self.attn_dropout = attn_dropout
+ self.proj_dropout = proj_dropout
+
+ # layers
+ self.to_qkv = nn.Linear(dim, dim * 3)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(self, x):
+ """
+ x: [B, L, C].
+ """
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
+
+ # compute attention
+ p = self.attn_dropout if self.training else 0.0
+ x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
+ x = x.reshape(b, s, c)
+
+ # output
+ x = self.proj(x)
+ x = F.dropout(x, self.proj_dropout, self.training)
+ return x
+
+
+class SwiGLU(nn.Module):
+
+ def __init__(self, dim, mid_dim):
+ super().__init__()
+ self.dim = dim
+ self.mid_dim = mid_dim
+
+ # layers
+ self.fc1 = nn.Linear(dim, mid_dim)
+ self.fc2 = nn.Linear(dim, mid_dim)
+ self.fc3 = nn.Linear(mid_dim, dim)
+
+ def forward(self, x):
+ x = F.silu(self.fc1(x)) * self.fc2(x)
+ x = self.fc3(x)
+ return x
+
+
+class AttentionBlock(nn.Module):
+
+ def __init__(self,
+ dim,
+ mlp_ratio,
+ num_heads,
+ post_norm=False,
+ causal=False,
+ activation='quick_gelu',
+ attn_dropout=0.0,
+ proj_dropout=0.0,
+ norm_eps=1e-5):
+ assert activation in ['quick_gelu', 'gelu', 'swi_glu']
+ super().__init__()
+ self.dim = dim
+ self.mlp_ratio = mlp_ratio
+ self.num_heads = num_heads
+ self.post_norm = post_norm
+ self.causal = causal
+ self.norm_eps = norm_eps
+
+ # layers
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
+ proj_dropout)
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
+ if activation == 'swi_glu':
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
+ else:
+ self.mlp = nn.Sequential(
+ nn.Linear(dim, int(dim * mlp_ratio)),
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
+
+ def forward(self, x):
+ if self.post_norm:
+ x = x + self.norm1(self.attn(x))
+ x = x + self.norm2(self.mlp(x))
+ else:
+ x = x + self.attn(self.norm1(x))
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class AttentionPool(nn.Module):
+
+ def __init__(self,
+ dim,
+ mlp_ratio,
+ num_heads,
+ activation='gelu',
+ proj_dropout=0.0,
+ norm_eps=1e-5):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.mlp_ratio = mlp_ratio
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.proj_dropout = proj_dropout
+ self.norm_eps = norm_eps
+
+ # layers
+ gain = 1.0 / math.sqrt(dim)
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
+ self.to_q = nn.Linear(dim, dim)
+ self.to_kv = nn.Linear(dim, dim * 2)
+ self.proj = nn.Linear(dim, dim)
+ self.norm = LayerNorm(dim, eps=norm_eps)
+ self.mlp = nn.Sequential(
+ nn.Linear(dim, int(dim * mlp_ratio)),
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
+
+ def forward(self, x):
+ """
+ x: [B, L, C].
+ """
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
+ k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
+
+ # compute attention
+ x = flash_attention(q, k, v, version=2)
+ x = x.reshape(b, 1, c)
+
+ # output
+ x = self.proj(x)
+ x = F.dropout(x, self.proj_dropout, self.training)
+
+ # mlp
+ x = x + self.mlp(self.norm(x))
+ return x[:, 0]
+
+
+class VisionTransformer(nn.Module):
+
+ def __init__(self,
+ image_size=224,
+ patch_size=16,
+ dim=768,
+ mlp_ratio=4,
+ out_dim=512,
+ num_heads=12,
+ num_layers=12,
+ pool_type='token',
+ pre_norm=True,
+ post_norm=False,
+ activation='quick_gelu',
+ attn_dropout=0.0,
+ proj_dropout=0.0,
+ embedding_dropout=0.0,
+ norm_eps=1e-5):
+ if image_size % patch_size != 0:
+ print(
+ '[WARNING] image_size is not divisible by patch_size',
+ flush=True)
+ assert pool_type in ('token', 'token_fc', 'attn_pool')
+ out_dim = out_dim or dim
+ super().__init__()
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_patches = (image_size // patch_size)**2
+ self.dim = dim
+ self.mlp_ratio = mlp_ratio
+ self.out_dim = out_dim
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.pool_type = pool_type
+ self.post_norm = post_norm
+ self.norm_eps = norm_eps
+
+ # embeddings
+ gain = 1.0 / math.sqrt(dim)
+ self.patch_embedding = nn.Conv2d(
+ 3,
+ dim,
+ kernel_size=patch_size,
+ stride=patch_size,
+ bias=not pre_norm)
+ if pool_type in ('token', 'token_fc'):
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
+ self.pos_embedding = nn.Parameter(gain * torch.randn(
+ 1, self.num_patches +
+ (1 if pool_type in ('token', 'token_fc') else 0), dim))
+ self.dropout = nn.Dropout(embedding_dropout)
+
+ # transformer
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
+ self.transformer = nn.Sequential(*[
+ AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
+ activation, attn_dropout, proj_dropout, norm_eps)
+ for _ in range(num_layers)
+ ])
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
+
+ # head
+ if pool_type == 'token':
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
+ elif pool_type == 'token_fc':
+ self.head = nn.Linear(dim, out_dim)
+ elif pool_type == 'attn_pool':
+ self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
+ proj_dropout, norm_eps)
+
+ def forward(self, x, interpolation=False, use_31_block=False):
+ b = x.size(0)
+
+ # embeddings
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
+ if self.pool_type in ('token', 'token_fc'):
+ x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
+ if interpolation:
+ e = pos_interpolate(self.pos_embedding, x.size(1))
+ else:
+ e = self.pos_embedding
+ x = self.dropout(x + e)
+ if self.pre_norm is not None:
+ x = self.pre_norm(x)
+
+ # transformer
+ if use_31_block:
+ x = self.transformer[:-1](x)
+ return x
+ else:
+ x = self.transformer(x)
+ return x
+
+
+class XLMRobertaWithHead(XLMRoberta):
+
+ def __init__(self, **kwargs):
+ self.out_dim = kwargs.pop('out_dim')
+ super().__init__(**kwargs)
+
+ # head
+ mid_dim = (self.dim + self.out_dim) // 2
+ self.head = nn.Sequential(
+ nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
+ nn.Linear(mid_dim, self.out_dim, bias=False))
+
+ def forward(self, ids):
+ # xlm-roberta
+ x = super().forward(ids)
+
+ # average pooling
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
+
+ # head
+ x = self.head(x)
+ return x
+
+
+class XLMRobertaCLIP(nn.Module):
+
+ def __init__(self,
+ embed_dim=1024,
+ image_size=224,
+ patch_size=14,
+ vision_dim=1280,
+ vision_mlp_ratio=4,
+ vision_heads=16,
+ vision_layers=32,
+ vision_pool='token',
+ vision_pre_norm=True,
+ vision_post_norm=False,
+ activation='gelu',
+ vocab_size=250002,
+ max_text_len=514,
+ type_size=1,
+ pad_id=1,
+ text_dim=1024,
+ text_heads=16,
+ text_layers=24,
+ text_post_norm=True,
+ text_dropout=0.1,
+ attn_dropout=0.0,
+ proj_dropout=0.0,
+ embedding_dropout=0.0,
+ norm_eps=1e-5):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.vision_dim = vision_dim
+ self.vision_mlp_ratio = vision_mlp_ratio
+ self.vision_heads = vision_heads
+ self.vision_layers = vision_layers
+ self.vision_pre_norm = vision_pre_norm
+ self.vision_post_norm = vision_post_norm
+ self.activation = activation
+ self.vocab_size = vocab_size
+ self.max_text_len = max_text_len
+ self.type_size = type_size
+ self.pad_id = pad_id
+ self.text_dim = text_dim
+ self.text_heads = text_heads
+ self.text_layers = text_layers
+ self.text_post_norm = text_post_norm
+ self.norm_eps = norm_eps
+
+ # models
+ self.visual = VisionTransformer(
+ image_size=image_size,
+ patch_size=patch_size,
+ dim=vision_dim,
+ mlp_ratio=vision_mlp_ratio,
+ out_dim=embed_dim,
+ num_heads=vision_heads,
+ num_layers=vision_layers,
+ pool_type=vision_pool,
+ pre_norm=vision_pre_norm,
+ post_norm=vision_post_norm,
+ activation=activation,
+ attn_dropout=attn_dropout,
+ proj_dropout=proj_dropout,
+ embedding_dropout=embedding_dropout,
+ norm_eps=norm_eps)
+ self.textual = XLMRobertaWithHead(
+ vocab_size=vocab_size,
+ max_seq_len=max_text_len,
+ type_size=type_size,
+ pad_id=pad_id,
+ dim=text_dim,
+ out_dim=embed_dim,
+ num_heads=text_heads,
+ num_layers=text_layers,
+ post_norm=text_post_norm,
+ dropout=text_dropout)
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
+
+ def forward(self, imgs, txt_ids):
+ """
+ imgs: [B, 3, H, W] of torch.float32.
+ - mean: [0.48145466, 0.4578275, 0.40821073]
+ - std: [0.26862954, 0.26130258, 0.27577711]
+ txt_ids: [B, L] of torch.long.
+ Encoded by data.CLIPTokenizer.
+ """
+ xi = self.visual(imgs)
+ xt = self.textual(txt_ids)
+ return xi, xt
+
+ def param_groups(self):
+ groups = [{
+ 'params': [
+ p for n, p in self.named_parameters()
+ if 'norm' in n or n.endswith('bias')
+ ],
+ 'weight_decay': 0.0
+ }, {
+ 'params': [
+ p for n, p in self.named_parameters()
+ if not ('norm' in n or n.endswith('bias'))
+ ]
+ }]
+ return groups
+
+
+def _clip(pretrained=False,
+ pretrained_name=None,
+ model_cls=XLMRobertaCLIP,
+ return_transforms=False,
+ return_tokenizer=False,
+ tokenizer_padding='eos',
+ dtype=torch.float32,
+ device='cpu',
+ **kwargs):
+ # init a model on device
+ with torch.device(device):
+ model = model_cls(**kwargs)
+
+ # set device
+ model = model.to(dtype=dtype, device=device)
+ output = (model,)
+
+ # init transforms
+ if return_transforms:
+ # mean and std
+ if 'siglip' in pretrained_name.lower():
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
+ else:
+ mean = [0.48145466, 0.4578275, 0.40821073]
+ std = [0.26862954, 0.26130258, 0.27577711]
+
+ # transforms
+ transforms = T.Compose([
+ T.Resize((model.image_size, model.image_size),
+ interpolation=T.InterpolationMode.BICUBIC),
+ T.ToTensor(),
+ T.Normalize(mean=mean, std=std)
+ ])
+ output += (transforms,)
+ return output[0] if len(output) == 1 else output
+
+
+def clip_xlm_roberta_vit_h_14(
+ pretrained=False,
+ pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
+ **kwargs):
+ cfg = dict(
+ embed_dim=1024,
+ image_size=224,
+ patch_size=14,
+ vision_dim=1280,
+ vision_mlp_ratio=4,
+ vision_heads=16,
+ vision_layers=32,
+ vision_pool='token',
+ activation='gelu',
+ vocab_size=250002,
+ max_text_len=514,
+ type_size=1,
+ pad_id=1,
+ text_dim=1024,
+ text_heads=16,
+ text_layers=24,
+ text_post_norm=True,
+ text_dropout=0.1,
+ attn_dropout=0.0,
+ proj_dropout=0.0,
+ embedding_dropout=0.0)
+ cfg.update(**kwargs)
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
+
+
+class CLIPModel:
+
+ def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
+ self.dtype = dtype
+ self.device = device
+ self.checkpoint_path = checkpoint_path
+ self.tokenizer_path = tokenizer_path
+
+ # init model
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
+ pretrained=False,
+ return_transforms=True,
+ return_tokenizer=False,
+ dtype=dtype,
+ device=device)
+ self.model = self.model.eval().requires_grad_(False)
+ logging.info(f'loading {checkpoint_path}')
+ self.model.load_state_dict(
+ torch.load(checkpoint_path, map_location='cpu'))
+
+ # init tokenizer
+ self.tokenizer = HuggingfaceTokenizer(
+ name=tokenizer_path,
+ seq_len=self.model.max_text_len - 2,
+ clean='whitespace')
+
+ def visual(self, videos):
+ # preprocess
+ size = (self.model.image_size,) * 2
+ videos = torch.cat([
+ F.interpolate(
+ u.transpose(0, 1),
+ size=size,
+ mode='bicubic',
+ align_corners=False) for u in videos
+ ])
+ videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
+
+ # forward
+ with torch.amp.autocast('cuda', dtype=self.dtype):
+ out = self.model.visual(videos, use_31_block=True)
+ return out
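+
+
+# --- Minimal smoke test of the vision tower (illustrative; not part of the upstream module) ---
+# Builds a deliberately tiny, randomly initialised VisionTransformer so the forward pass can be
+# checked without any checkpoints; the sizes below are arbitrary, not a real CLIP configuration.
+# Note that `CLIPModel.visual` expects a list of [C, F, H, W] video tensors in [-1, 1], which are
+# rescaled via `mul_(0.5).add_(0.5)` before normalisation.
+if __name__ == "__main__":
+    tiny_vit = VisionTransformer(
+        image_size=32, patch_size=16, dim=64, mlp_ratio=4, out_dim=32,
+        num_heads=4, num_layers=2, pool_type='token')
+    tokens = tiny_vit(torch.randn(1, 3, 32, 32))
+    print("ViT token shape:", tuple(tokens.shape))  # expected: (1, 1 + (32 // 16) ** 2, 64)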
diff --git a/humo/models/wan_modules/model.py b/humo/models/wan_modules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a9c0f3e5869d4e1d1f901f11e58e4678a850dab
--- /dev/null
+++ b/humo/models/wan_modules/model.py
@@ -0,0 +1,619 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import math
+
+import torch
+import torch.cuda.amp as amp
+import torch.nn as nn
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+
+from .attention import flash_attention
+
+__all__ = ['WanModel']
+
+
+def sinusoidal_embedding_1d(dim, position):
+ # preprocess
+ assert dim % 2 == 0
+ half = dim // 2
+ position = position.type(torch.float64)
+
+ # calculation
+ sinusoid = torch.outer(
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
+ return x
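+
+# Shape note (explanatory comment, not upstream): sinusoidal_embedding_1d(dim, position) returns a
+# [len(position), dim] float64 tensor, with the cosine components in the first dim // 2 channels
+# and the sine components in the second half.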
+
+
+@torch.amp.autocast("cuda", enabled=False)
+def rope_params(max_seq_len, dim, theta=10000):
+ assert dim % 2 == 0
+ freqs = torch.outer(
+ torch.arange(max_seq_len),
+ 1.0 / torch.pow(theta,
+ torch.arange(0, dim, 2).to(torch.float64).div(dim)))
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
+ return freqs
+
+
+@torch.amp.autocast("cuda", enabled=False)
+def rope_apply(x, grid_sizes, freqs):
+ n, c = x.size(2), x.size(3) // 2
+
+ # split freqs
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+
+ # loop over samples
+ output = []
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+ seq_len = f * h * w
+
+ # precompute multipliers
+ x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
+ seq_len, n, -1, 2))
+ freqs_i = torch.cat([
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+ ],
+ dim=-1).reshape(seq_len, 1, -1)
+
+ # apply rotary embedding
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
+ x_i = torch.cat([x_i, x[i, seq_len:]])
+
+ # append to collection
+ output.append(x_i)
+ return torch.stack(output).float()
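+
+# Note on rope_apply (explanatory comment, not upstream): with c = head_dim // 2 complex rotations
+# per token, the frequency table is split into [c - 2 * (c // 3), c // 3, c // 3] banks for the
+# temporal, height and width axes and broadcast over each sample's (F, H, W) token grid; e.g. the
+# defaults dim=5120, num_heads=40 give head_dim=128 and banks of 22/21/21 rotations. Tokens beyond
+# f * h * w (padding) are concatenated back unrotated.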
+
+
+class WanRMSNorm(nn.Module):
+
+ def __init__(self, dim, eps=1e-5):
+ super().__init__()
+ self.dim = dim
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C]
+ """
+ return self._norm(x.float()).type_as(x) * self.weight
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+
+
+class WanLayerNorm(nn.LayerNorm):
+
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
+
+ def forward(self, x):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C]
+ """
+ return super().forward(x.float()).type_as(x)
+
+
+class WanSelfAttention(nn.Module):
+
+ def __init__(self,
+ dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ eps=1e-6):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.window_size = window_size
+ self.qk_norm = qk_norm
+ self.eps = eps
+
+ # layers
+ self.q = nn.Linear(dim, dim)
+ self.k = nn.Linear(dim, dim)
+ self.v = nn.Linear(dim, dim)
+ self.o = nn.Linear(dim, dim)
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+
+ def forward(self, x, seq_lens, grid_sizes, freqs):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
+ seq_lens(Tensor): Shape [B]
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+ """
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+
+ # query, key, value function
+ def qkv_fn(x):
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
+ v = self.v(x).view(b, s, n, d)
+ return q, k, v
+
+ q, k, v = qkv_fn(x)
+
+ x = flash_attention(
+ q=rope_apply(q, grid_sizes, freqs),
+ k=rope_apply(k, grid_sizes, freqs),
+ v=v,
+ k_lens=seq_lens,
+ window_size=self.window_size)
+
+ # output
+ x = x.flatten(2)
+ x = self.o(x)
+ return x
+
+
+class WanT2VCrossAttention(WanSelfAttention):
+
+ def forward(self, x, context, context_lens):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C]
+ context(Tensor): Shape [B, L2, C]
+ context_lens(Tensor): Shape [B]
+ """
+ b, n, d = x.size(0), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
+ v = self.v(context).view(b, -1, n, d)
+
+ # compute attention
+ x = flash_attention(q, k, v, k_lens=context_lens)
+
+ # output
+ x = x.flatten(2)
+ x = self.o(x)
+ return x
+
+
+class WanI2VCrossAttention(WanSelfAttention):
+
+ def __init__(self,
+ dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ eps=1e-6):
+ super().__init__(dim, num_heads, window_size, qk_norm, eps)
+
+ self.k_img = nn.Linear(dim, dim)
+ self.v_img = nn.Linear(dim, dim)
+ # self.alpha = nn.Parameter(torch.zeros((1, )))
+ self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+
+ def forward(self, x, context, context_lens):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C]
+ context(Tensor): Shape [B, L2, C]
+ context_lens(Tensor): Shape [B]
+ """
+ context_img = context[:, :257]
+ context = context[:, 257:]
+ b, n, d = x.size(0), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
+ v = self.v(context).view(b, -1, n, d)
+ k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
+ v_img = self.v_img(context_img).view(b, -1, n, d)
+ img_x = flash_attention(q, k_img, v_img, k_lens=None)
+ # compute attention
+ x = flash_attention(q, k, v, k_lens=context_lens)
+
+ # output
+ x = x.flatten(2)
+ img_x = img_x.flatten(2)
+ x = x + img_x
+ x = self.o(x)
+ return x
+
+
+WAN_CROSSATTENTION_CLASSES = {
+ 't2v_cross_attn': WanT2VCrossAttention,
+ 'i2v_cross_attn': WanI2VCrossAttention,
+}
+
+
+class WanAttentionBlock(nn.Module):
+
+ def __init__(self,
+ cross_attn_type,
+ dim,
+ ffn_dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=False,
+ eps=1e-6):
+ super().__init__()
+ self.dim = dim
+ self.ffn_dim = ffn_dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.qk_norm = qk_norm
+ self.cross_attn_norm = cross_attn_norm
+ self.eps = eps
+
+ # layers
+ self.norm1 = WanLayerNorm(dim, eps)
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
+ eps)
+ self.norm3 = WanLayerNorm(
+ dim, eps,
+ elementwise_affine=True) if cross_attn_norm else nn.Identity()
+ self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
+ num_heads,
+ (-1, -1),
+ qk_norm,
+ eps)
+ self.norm2 = WanLayerNorm(dim, eps)
+ self.ffn = nn.Sequential(
+ nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
+ nn.Linear(ffn_dim, dim))
+
+ # modulation
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ x,
+ e,
+ seq_lens,
+ grid_sizes,
+ freqs,
+ context,
+ context_lens,
+ ):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C]
+ e(Tensor): Shape [B, 6, C]
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+ """
+ assert e.dtype == torch.float32
+ with torch.amp.autocast('cuda', dtype=torch.float32):
+ e = (self.modulation + e).chunk(6, dim=1)
+ assert e[0].dtype == torch.float32
+
+ # self-attention
+ y = self.self_attn(
+ self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
+ freqs)
+ with torch.amp.autocast('cuda', dtype=torch.float32):
+ x = x + y * e[2]
+
+ # cross-attention & ffn function
+ def cross_attn_ffn(x, context, context_lens, e):
+ x = x + self.cross_attn(self.norm3(x), context, context_lens)
+ y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
+ with torch.amp.autocast('cuda', dtype=torch.float32):
+ x = x + y * e[5]
+ return x
+
+ x = cross_attn_ffn(x, context, context_lens, e)
+ return x
+
+
+class Head(nn.Module):
+
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
+ super().__init__()
+ self.dim = dim
+ self.out_dim = out_dim
+ self.patch_size = patch_size
+ self.eps = eps
+
+ # layers
+ out_dim = math.prod(patch_size) * out_dim
+ self.norm = WanLayerNorm(dim, eps)
+ self.head = nn.Linear(dim, out_dim)
+
+ # modulation
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
+
+ def forward(self, x, e):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C]
+ e(Tensor): Shape [B, C]
+ """
+ assert e.dtype == torch.float32
+ with torch.amp.autocast('cuda', dtype=torch.float32):
+ e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
+ x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
+ return x
+
+
+class MLPProj(torch.nn.Module):
+
+ def __init__(self, in_dim, out_dim):
+ super().__init__()
+
+ self.proj = torch.nn.Sequential(
+ torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
+ torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
+ torch.nn.LayerNorm(out_dim))
+
+ def forward(self, image_embeds):
+ clip_extra_context_tokens = self.proj(image_embeds)
+ return clip_extra_context_tokens
+
+
+class WanModel(ModelMixin, ConfigMixin):
+ r"""
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
+ """
+
+ ignore_for_config = [
+ 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
+ ]
+ _no_split_modules = ['WanAttentionBlock']
+
+ @register_to_config
+ def __init__(self,
+ model_type='t2v',
+ patch_size=(1, 2, 2),
+ text_len=512,
+ in_dim=16,
+ dim=5120,
+ ffn_dim=13824,
+ freq_dim=256,
+ text_dim=4096,
+ out_dim=16,
+ num_heads=40,
+ num_layers=40,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=True,
+ eps=1e-6):
+ r"""
+ Initialize the diffusion model backbone.
+
+ Args:
+ model_type (`str`, *optional*, defaults to 't2v'):
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
+ text_len (`int`, *optional*, defaults to 512):
+ Fixed length for text embeddings
+ in_dim (`int`, *optional*, defaults to 16):
+ Input video channels (C_in)
+ dim (`int`, *optional*, defaults to 5120):
+ Hidden dimension of the transformer
+ ffn_dim (`int`, *optional*, defaults to 13824):
+ Intermediate dimension in feed-forward network
+ freq_dim (`int`, *optional*, defaults to 256):
+ Dimension for sinusoidal time embeddings
+ text_dim (`int`, *optional*, defaults to 4096):
+ Input dimension for text embeddings
+ out_dim (`int`, *optional*, defaults to 16):
+ Output video channels (C_out)
+ num_heads (`int`, *optional*, defaults to 40):
+ Number of attention heads
+ num_layers (`int`, *optional*, defaults to 40):
+ Number of transformer blocks
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
+ Window size for local attention (-1 indicates global attention)
+ qk_norm (`bool`, *optional*, defaults to True):
+ Enable query/key normalization
+ cross_attn_norm (`bool`, *optional*, defaults to True):
+ Enable cross-attention normalization
+ eps (`float`, *optional*, defaults to 1e-6):
+ Epsilon value for normalization layers
+ """
+
+ super().__init__()
+
+ assert model_type in ['t2v', 'i2v']
+ self.model_type = model_type
+
+ self.patch_size = patch_size
+ self.text_len = text_len
+ self.in_dim = in_dim
+ self.dim = dim
+ self.ffn_dim = ffn_dim
+ self.freq_dim = freq_dim
+ self.text_dim = text_dim
+ self.out_dim = out_dim
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.window_size = window_size
+ self.qk_norm = qk_norm
+ self.cross_attn_norm = cross_attn_norm
+ self.eps = eps
+
+ # embeddings
+ self.patch_embedding = nn.Conv3d(
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
+ self.text_embedding = nn.Sequential(
+ nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
+ nn.Linear(dim, dim))
+
+ self.time_embedding = nn.Sequential(
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
+
+ # blocks
+ cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
+ self.blocks = nn.ModuleList([
+ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
+ window_size, qk_norm, cross_attn_norm, eps)
+ for _ in range(num_layers)
+ ])
+
+ # head
+ self.head = Head(dim, out_dim, patch_size, eps)
+
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
+ d = dim // num_heads
+ self.freqs = torch.cat([
+ rope_params(1024, d - 4 * (d // 6)),
+ rope_params(1024, 2 * (d // 6)),
+ rope_params(1024, 2 * (d // 6))
+ ],
+ dim=1)
+
+ if model_type == 'i2v':
+ self.img_emb = MLPProj(1280, dim)
+
+ # initialize weights
+ self.init_weights()
+
+ def forward(
+ self,
+ x,
+ t,
+ context,
+ seq_len,
+ clip_fea=None,
+ y=None,
+ ):
+ r"""
+ Forward pass through the diffusion model
+
+ Args:
+ x (List[Tensor]):
+ List of input video tensors, each with shape [C_in, F, H, W]
+ t (Tensor):
+ Diffusion timesteps tensor of shape [B]
+ context (List[Tensor]):
+ List of text embeddings each with shape [L, C]
+ seq_len (`int`):
+ Maximum sequence length for positional encoding
+ clip_fea (Tensor, *optional*):
+ CLIP image features for image-to-video mode
+ y (List[Tensor], *optional*):
+ Conditional video inputs for image-to-video mode, same shape as x
+
+ Returns:
+ List[Tensor]:
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
+ """
+ if self.model_type == 'i2v':
+ assert clip_fea is not None and y is not None
+ # params
+ device = self.patch_embedding.weight.device
+ freqs = self.freqs.to(device)
+
+ if y is not None:
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+
+ # embeddings
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
+ grid_sizes = torch.stack(
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
+ x = [u.flatten(2).transpose(1, 2) for u in x]
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
+ assert seq_lens.max() <= seq_len
+ x = torch.cat([
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
+ dim=1) for u in x
+ ])
+
+ # time embeddings
+ with torch.amp.autocast('cuda', dtype=torch.float32):
+ e = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
+
+ # context
+ context_lens = None
+ context = self.text_embedding(
+ torch.stack([
+ torch.cat(
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
+ for u in context
+ ]))
+
+ if clip_fea is not None:
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
+ context = torch.concat([context_clip, context], dim=1)
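+ # Note: in i2v mode the context becomes [B, 257 + text_len, C]; the 257 projected
+ # CLIP image tokens are prepended to the padded text embeddings.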
+
+ # arguments
+ kwargs = dict(
+ e=e0,
+ seq_lens=seq_lens,
+ grid_sizes=grid_sizes,
+ freqs=freqs,
+ context=context,
+ context_lens=context_lens)
+
+ for block in self.blocks:
+ x = block(x, **kwargs)
+
+ # head
+ x = self.head(x, e)
+
+ # unpatchify
+ x = self.unpatchify(x, grid_sizes)
+ return [u.float() for u in x]
+
+ def unpatchify(self, x, grid_sizes):
+ r"""
+ Reconstruct video tensors from patch embeddings.
+
+ Args:
+ x (List[Tensor]):
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
+ grid_sizes (Tensor):
+ Original spatial-temporal grid dimensions before patching,
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
+
+ Returns:
+ List[Tensor]:
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
+ """
+
+ c = self.out_dim
+ out = []
+ for u, v in zip(x, grid_sizes.tolist()):
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
+ out.append(u)
+ return out
+
+ def init_weights(self):
+ r"""
+ Initialize model parameters using Xavier initialization.
+ """
+
+ # basic init
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ # init embeddings
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
+ for m in self.text_embedding.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, std=.02)
+ for m in self.time_embedding.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, std=.02)
+
+ # init output layer
+ nn.init.zeros_(self.head.head.weight)
diff --git a/humo/models/wan_modules/model_humo.py b/humo/models/wan_modules/model_humo.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7190fe08b49f8c05e7f29dcc956fba4bd7ecaa5
--- /dev/null
+++ b/humo/models/wan_modules/model_humo.py
@@ -0,0 +1,803 @@
+import torch
+from torch import nn
+
+from common.distributed import get_device
+from models.audio.audio_proj import AudioProjModel
+
+import torch.cuda.amp as amp
+import math
+from humo.models.wan_modules.attention import flash_attention
+from common.distributed.advanced import is_unified_parallel_initialized
+
+import types
+
+def sinusoidal_embedding_1d(dim, position):
+ # preprocess
+ assert dim % 2 == 0
+ half = dim // 2
+ position = position.type(torch.float64)
+
+ # calculation
+ sinusoid = torch.outer(
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
+ return x
+
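+# Note on sinusoidal_embedding_1d: for a timestep tensor of shape [B] it returns a
+# [B, dim] embedding whose first half holds cos terms and second half sin terms, e.g.
+# (illustrative only) sinusoidal_embedding_1d(256, torch.tensor([0, 500, 999])) -> [3, 256].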
+
+@amp.autocast(enabled=False)
+def rope_params(max_seq_len, dim, theta=10000):
+ assert dim % 2 == 0
+ freqs = torch.outer(
+ torch.arange(max_seq_len),
+ 1.0 / torch.pow(theta,
+ torch.arange(0, dim, 2).to(torch.float32).div(dim)))
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
+ return freqs
+
+
+@amp.autocast(enabled=False)
+def rope_apply(x, grid_sizes, freqs):
+ n, c = x.size(2), x.size(3) // 2
+
+ # split freqs
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+
+ # loop over samples
+ output = []
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+ seq_len = f * h * w
+
+ # precompute multipliers
+ x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape(
+ seq_len, n, -1, 2))
+ freqs_i = torch.cat([
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+ ],
+ dim=-1).reshape(seq_len, 1, -1)
+
+ # apply rotary embedding
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
+ x_i = torch.cat([x_i, x[i, seq_len:]])
+
+ # append to collection
+ output.append(x_i)
+ return torch.stack(output).float()
+
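+# Illustration of the rope_apply split (assuming the 5120-dim / 40-head configuration,
+# so head_dim d = 128 and c = 64 complex channels): the frequencies are divided into
+# 22 temporal + 21 + 21 spatial channels, and a token at grid position (f, h, w) is
+# rotated by the concatenation of freqs[0][f], freqs[1][h] and freqs[2][w].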
+
+class WanRMSNorm(nn.Module):
+
+ def __init__(self, dim, eps=1e-5):
+ super().__init__()
+ self.dim = dim
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C]
+ """
+ return self._norm(x.float()).type_as(x) * self.weight
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+
+
+class WanLayerNorm(nn.LayerNorm):
+
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
+
+ def forward(self, x):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C]
+ """
+ return super().forward(x.float()).type_as(x)
+
+
+class WanSelfAttention(nn.Module):
+
+ def __init__(self,
+ dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ eps=1e-6):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.window_size = window_size
+ self.qk_norm = qk_norm
+ self.eps = eps
+
+ # layers
+ self.q = nn.Linear(dim, dim)
+ self.k = nn.Linear(dim, dim)
+ self.v = nn.Linear(dim, dim)
+ self.o = nn.Linear(dim, dim)
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+
+ def forward(self, x, seq_lens, grid_sizes, freqs):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C], e.g. torch.Size([1, 9360, 5120])
+ seq_lens(Tensor): Shape [B], tensor([9360])
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W), tensor([[ 6, 30, 52]])
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+ """
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+
+ # query, key, value function
+ def qkv_fn(x):
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
+ v = self.v(x).view(b, s, n, d)
+ return q, k, v
+
+ q, k, v = qkv_fn(x)
+
+ x = flash_attention(
+ q=rope_apply(q, grid_sizes, freqs),
+ k=rope_apply(k, grid_sizes, freqs),
+ v=v,
+ k_lens=seq_lens,
+ window_size=self.window_size)
+
+ # output
+ x = x.flatten(2)
+ x = self.o(x)
+ return x
+
+
+class WanSelfAttentionSepKVDim(nn.Module):
+
+ def __init__(self,
+ kv_dim,
+ dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ eps=1e-6):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.window_size = window_size
+ self.qk_norm = qk_norm
+ self.eps = eps
+
+ # layers
+ self.q = nn.Linear(dim, dim)
+ self.k = nn.Linear(kv_dim, dim)
+ self.v = nn.Linear(kv_dim, dim)
+ self.o = nn.Linear(dim, dim)
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+
+ def forward(self, x, seq_lens, grid_sizes, freqs):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C], e.g. torch.Size([1, 9360, 5120])
+ seq_lens(Tensor): Shape [B], tensor([9360])
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W), tensor([[ 6, 30, 52]])
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+ """
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+
+ # query, key, value function
+ def qkv_fn(x):
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
+ v = self.v(x).view(b, s, n, d)
+ return q, k, v
+
+ q, k, v = qkv_fn(x)
+
+ x = flash_attention(
+ q=rope_apply(q, grid_sizes, freqs),
+ k=rope_apply(k, grid_sizes, freqs),
+ v=v,
+ k_lens=seq_lens,
+ window_size=self.window_size)
+
+ # output
+ x = x.flatten(2)
+ x = self.o(x)
+ return x
+
+
+class WanT2VCrossAttention(WanSelfAttention):
+
+ def forward(self, x, context, context_lens):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C]
+ context(Tensor): Shape [B, L2, C]
+ context_lens(Tensor): Shape [B]
+ """
+ b, n, d = x.size(0), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
+ v = self.v(context).view(b, -1, n, d)
+
+ # compute attention
+ x = flash_attention(q, k, v, k_lens=context_lens)
+
+ # output
+ x = x.flatten(2)
+ x = self.o(x)
+ return x
+
+
+class WanT2VCrossAttentionGather(WanSelfAttentionSepKVDim):
+
+ def forward(self, x, context, context_lens, grid_sizes, freqs, audio_seq_len):
+ b, n, d = x.size(0), self.num_heads, self.head_dim
+
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
+ v = self.v(context).view(b, -1, n, d)
+
+ # --- NEW: derive sizes from shapes (SymInts), no int(tensor) casts ---
+ Lq = q.shape[1] # total video tokens per sample
+ # audio has 16 tokens per frame -> frames = audio_tokens // 16
+ frames = (context.shape[1] // 16)
+ hlen_wlen = Lq // frames # tokens per frame = H*W
+
+ # Now reshape using SymInt-derived sizes
+ q = q.reshape(-1, hlen_wlen, n, d)
+ k = k.reshape(-1, 16, n, d)
+ v = v.reshape(-1, 16, n, d)
+
+ x = flash_attention(q, k, v, k_lens=None)
+ x = x.view(b, -1, n, d).flatten(2)
+ x = self.o(x)
+ return x
+
+ # def forward(self, x, context, context_lens, grid_sizes, freqs, audio_seq_len):
+ # r"""
+ # Args:
+ # x(Tensor): Shape [B, L1, C] - video tokens
+ # context(Tensor): Shape [B, L2, C] - audio tokens with shape [B, frames*16, 1536]
+ # context_lens(Tensor): Shape [B] - actually seq_lens from call (video sequence length)
+ # grid_sizes(Tensor): Shape [B, 3] - video grid dimensions (F, H, W)
+ # freqs(Tensor): RoPE frequencies
+ # audio_seq_len(Tensor): Actual audio sequence length (frames * 16)
+ # """
+ # b, n, d = x.size(0), self.num_heads, self.head_dim
+
+ # q = self.norm_q(self.q(x)).view(b, -1, n, d)
+ # k = self.norm_k(self.k(context)).view(b, -1, n, d)
+ # v = self.v(context).view(b, -1, n, d)
+
+ # # Handle video spatial structure
+ # hlen_wlen = int(grid_sizes[0][1] * grid_sizes[0][2])
+ # q = q.reshape(-1, hlen_wlen, n, d)
+
+ # # Handle audio temporal structure (16 tokens per frame)
+ # k = k.reshape(-1, 16, n, d)
+ # v = v.reshape(-1, 16, n, d)
+
+ # # Cross-attention
+ # x = flash_attention(q, k, v, k_lens=None) # No masking for audio
+
+ # x = x.view(b, -1, n, d).flatten(2)
+ # x = self.o(x)
+ # return x
+
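+# Shape walkthrough for WanT2VCrossAttentionGather (illustrative numbers): with
+# grid_sizes = [[6, 30, 52]] a sample has Lq = 6*30*52 = 9360 video tokens and the
+# audio context holds 6*16 = 96 tokens, so frames = 96 // 16 = 6 and
+# hlen_wlen = 9360 // 6 = 1560; the reshapes therefore pair each latent frame's
+# 1560 spatial tokens with only that frame's 16 audio tokens inside flash_attention.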
+
+class AudioCrossAttentionWrapper(nn.Module):
+ def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6,):
+ super().__init__()
+
+ self.audio_cross_attn = WanT2VCrossAttentionGather(
+ kv_dim, dim, num_heads, (-1, -1), qk_norm, eps)
+ self.norm1_audio = WanLayerNorm(dim, eps,
+ elementwise_affine=True)
+
+ def forward(self, x, audio, seq_lens, grid_sizes, freqs, audio_seq_len):
+ x = x + self.audio_cross_attn(
+ self.norm1_audio(x), audio, seq_lens, grid_sizes, freqs, audio_seq_len)
+ return x
+
+
+class WanI2VCrossAttention(WanSelfAttention):
+
+ def __init__(self,
+ dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ eps=1e-6):
+ super().__init__(dim, num_heads, window_size, qk_norm, eps)
+
+ def forward(self, x, context, context_lens):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C]
+ context(Tensor): Shape [B, L2, C]
+ context_lens(Tensor): Shape [B]
+ """
+ b, n, d = x.size(0), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
+ v = self.v(context).view(b, -1, n, d)
+ x = flash_attention(q, k, v, k_lens=context_lens)
+
+ # output
+ x = x.flatten(2)
+ x = self.o(x)
+ return x
+
+
+WAN_CROSSATTENTION_CLASSES = {
+ 't2v_cross_attn': WanT2VCrossAttention,
+ 'i2v_cross_attn': WanI2VCrossAttention,
+}
+
+class WanAttentionBlock(nn.Module):
+
+ def __init__(self,
+ cross_attn_type,
+ dim,
+ ffn_dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=False,
+ eps=1e-6,
+ use_audio=True):
+ super().__init__()
+ self.dim = dim
+ self.ffn_dim = ffn_dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.qk_norm = qk_norm
+ self.cross_attn_norm = cross_attn_norm
+ self.eps = eps
+
+ # layers
+ self.norm1 = WanLayerNorm(dim, eps)
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
+ eps)
+ self.norm3 = WanLayerNorm(
+ dim, eps,
+ elementwise_affine=True) if cross_attn_norm else nn.Identity()
+ self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
+ num_heads,
+ (-1, -1),
+ qk_norm,
+ eps)
+ self.norm2 = WanLayerNorm(dim, eps)
+ self.ffn = nn.Sequential(
+ nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
+ nn.Linear(ffn_dim, dim))
+
+ # modulation
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+ self.use_audio = use_audio
+ if use_audio:
+ self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(dim, 1536, num_heads, qk_norm, eps)
+
+ def forward(
+ self,
+ x, # torch.Size([1, 9360, 5120])
+ e, # torch.Size([1, 6, 5120])
+ seq_lens, # tensor([9360])
+ grid_sizes, # tensor([[ 6, 30, 52]])
+ freqs, # torch.Size([1024, 64])
+ context, # torch.Size([1, 512, 5120])
+ context_lens, # None
+ audio=None, # None
+ audio_seq_len=None,
+ ref_num_list=None,
+ ):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C]
+ e(Tensor): Shape [B, 6, C]
+ audio(Tensor): Shape [B, L_audio, C_audio], projected audio tokens (16 per latent frame)
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+ ref_num_list: used together with seq_lens to locate the reference image counting from the end of the sequence
+ """
+ assert e.dtype == torch.float32
+ with torch.amp.autocast('cuda', dtype=torch.float32):
+ e = (self.modulation + e).chunk(6, dim=1)
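+ # e[0]/e[1]: shift/scale before self-attention, e[2]: its residual gate;
+ # e[3]/e[4]: shift/scale before the FFN, e[5]: its residual gate (adaLN-style).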
+ assert e[0].dtype == torch.float32
+
+ # self-attention
+ y = self.self_attn(
+ self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
+ freqs)
+ with torch.amp.autocast('cuda', dtype=torch.float32):
+ x = x + y * e[2]
+
+ # cross-attention & ffn function
+ def cross_attn_ffn(x, context, context_lens, e):
+ x = x + self.cross_attn(self.norm3(x), context, context_lens)
+
+ if self.use_audio:
+ x = self.audio_cross_attn_wrapper(x, audio, seq_lens, grid_sizes, freqs, audio_seq_len)
+
+ y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
+ with torch.amp.autocast('cuda', dtype=torch.float32):
+ x = x + y * e[5]
+ return x
+
+ x = cross_attn_ffn(x, context, context_lens, e)
+
+ return x
+
+
+class Head(nn.Module):
+
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
+ super().__init__()
+ self.dim = dim
+ self.out_dim = out_dim
+ self.patch_size = patch_size
+ self.eps = eps
+
+ # layers
+ out_dim = math.prod(patch_size) * out_dim
+ self.norm = WanLayerNorm(dim, eps)
+ self.head = nn.Linear(dim, out_dim)
+
+ # modulation
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
+
+ def forward(self, x, e):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C]
+ e(Tensor): Shape [B, C]
+ """
+ assert e.dtype == torch.float32
+ with torch.amp.autocast('cuda', dtype=torch.float32):
+ e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
+ x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
+ return x
+
+
+class MLPProj(torch.nn.Module):
+
+ def __init__(self, in_dim, out_dim):
+ super().__init__()
+
+ self.proj = torch.nn.Sequential(
+ torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
+ torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
+ torch.nn.LayerNorm(out_dim))
+
+ def forward(self, image_embeds):
+ clip_extra_context_tokens = self.proj(image_embeds)
+ return clip_extra_context_tokens
+
+
+class WanModel(nn.Module):
+ r"""
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
+ """
+
+ ignore_for_config = [
+ 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
+ ]
+ _no_split_modules = ['WanAttentionBlock']
+
+ gradient_checkpointing = False
+
+ def __init__(self,
+ model_type='t2v',
+ patch_size=(1, 2, 2),
+ text_len=512,
+ in_dim=16,
+ dim=2048,
+ ffn_dim=13824,
+ freq_dim=256,
+ text_dim=4096,
+ out_dim=16,
+ num_heads=40,
+ num_layers=40,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=True,
+ eps=1e-6,
+ audio_token_num=16,
+ insert_audio=True):
+ r"""
+ Initialize the diffusion model backbone.
+
+ Args:
+ model_type (`str`, *optional*, defaults to 't2v'):
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
+ text_len (`int`, *optional*, defaults to 512):
+ Fixed length for text embeddings
+ in_dim (`int`, *optional*, defaults to 16):
+ Input video channels (C_in)
+ dim (`int`, *optional*, defaults to 2048):
+ Hidden dimension of the transformer
+ ffn_dim (`int`, *optional*, defaults to 13824):
+ Intermediate dimension in feed-forward network
+ freq_dim (`int`, *optional*, defaults to 256):
+ Dimension for sinusoidal time embeddings
+ text_dim (`int`, *optional*, defaults to 4096):
+ Input dimension for text embeddings
+ out_dim (`int`, *optional*, defaults to 16):
+ Output video channels (C_out)
+ num_heads (`int`, *optional*, defaults to 40):
+ Number of attention heads
+ num_layers (`int`, *optional*, defaults to 40):
+ Number of transformer blocks
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
+ Window size for local attention (-1 indicates global attention)
+ qk_norm (`bool`, *optional*, defaults to True):
+ Enable query/key normalization
+ cross_attn_norm (`bool`, *optional*, defaults to True):
+ Enable cross-attention normalization
+ eps (`float`, *optional*, defaults to 1e-6):
+ Epsilon value for normalization layers
+ audio_token_num (`int`, *optional*, defaults to 16):
+ Number of audio context tokens produced by the audio projector per latent frame
+ insert_audio (`bool`, *optional*, defaults to True):
+ Whether to add an audio cross-attention branch to every transformer block
+ """
+
+ super().__init__()
+
+ assert model_type in ['t2v', 'i2v']
+ self.model_type = model_type
+
+ self.patch_size = patch_size
+ self.text_len = text_len
+ self.in_dim = in_dim
+ self.dim = dim
+ self.ffn_dim = ffn_dim
+ self.freq_dim = freq_dim
+ self.text_dim = text_dim
+ self.out_dim = out_dim
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.window_size = window_size
+ self.qk_norm = qk_norm
+ self.cross_attn_norm = cross_attn_norm
+ self.eps = eps
+
+ # embeddings
+ self.patch_embedding = nn.Conv3d(
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
+ self.text_embedding = nn.Sequential(
+ nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
+ nn.Linear(dim, dim))
+
+ self.time_embedding = nn.Sequential(
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
+
+ # blocks
+ cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
+ self.insert_audio = insert_audio
+ self.blocks = nn.ModuleList([
+ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
+ window_size, qk_norm, cross_attn_norm,
+ eps, use_audio=self.insert_audio)
+ for _ in range(num_layers)
+ ])
+
+ # head
+ self.head = Head(dim, out_dim, patch_size, eps)
+
+ if self.insert_audio:
+ self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280,
+ intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num)
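+ # The projector emits `audio_token_num` (16 by default) tokens of width 1536 per
+ # latent frame, matching the kv_dim of the per-frame audio cross-attention in each block.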
+
+ # RoPE freqs: register as a buffer so it moves with .to() / DDP and is tracked by compile
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
+ d = dim // num_heads
+
+ _freqs = torch.cat([
+ rope_params(1024, d - 4 * (d // 6)),
+ rope_params(1024, 2 * (d // 6)),
+ rope_params(1024, 2 * (d // 6))
+ ], dim=1)
+ self.register_buffer("freqs", _freqs, persistent=False)
+
+ # initialize weights
+ self.init_weights()
+
+ # initialize unified parallel
+ if is_unified_parallel_initialized():
+ print(f"Initializing WanModel with unified parallel initialized")
+ from humo.models.distributed.dit_ulysses_sequence_parallel import ulysses_attn_forward, ulysses_dit_forward, ulysses_audio_cross_attn_forward
+ for block in self.blocks:
+ block.self_attn.forward = types.MethodType(ulysses_attn_forward, block.self_attn)
+ if block.use_audio:
+ block.audio_cross_attn_wrapper.audio_cross_attn.forward = types.MethodType(ulysses_audio_cross_attn_forward, block.audio_cross_attn_wrapper.audio_cross_attn)
+ self.forward = types.MethodType(ulysses_dit_forward, self)
+
+ def forward(
+ self,
+ x,
+ t,
+ context,
+ seq_len,
+ audio=None,
+ y=None,
+ tea_cache=None,
+ ):
+ r"""
+ Forward pass through the diffusion model
+
+ Args:
+ x (List[Tensor]):
+ List of input video tensors, each with shape [C_in, F, H, W]
+ t (Tensor):
+ Diffusion timesteps tensor of shape [B]
+ context (List[Tensor]):
+ List of text embeddings each with shape [L, C]
+ seq_len (`int`):
+ Maximum sequence length for positional encoding
+ audio (List[Tensor], *optional*):
+ Audio features for audio-conditioned generation; projected by AudioProjModel when `insert_audio` is enabled
+ y (List[Tensor], *optional*):
+ Conditional video inputs for image-to-video mode, same shape as x
+ tea_cache (*optional*):
+ Optional TeaCache-style helper; when provided, its check/update/store hooks are used to reuse cached block outputs
+
+ Returns:
+ List[Tensor]:
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
+ """
+ if self.model_type == 'i2v':
+ assert y is not None
+
+ # params
+ freqs = self.freqs
+
+ if y is not None:
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+
+ # embeddings
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
+ grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
+
+ x = [u.flatten(2).transpose(1, 2) for u in x]
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
+ assert seq_lens.max() <= seq_len
+
+ # pad to uniform length and batch
+ x = torch.cat([
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
+ for u in x
+ ]) # shape: [B, seq_len, C]
+
+ # time embeddings
+ with torch.amp.autocast('cuda', dtype=torch.float32):
+ e = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, t).float()
+ ).float()
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim)).float()
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
+
+ # context
+ context_lens = None
+ context = self.text_embedding(
+ torch.stack([
+ torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
+ for u in context
+ ])
+ )
+
+ # audio (unchanged; not cached)
+ if self.insert_audio:
+ audio = [self.audio_proj(au.unsqueeze(0)).permute(0, 3, 1, 2) for au in audio]
+ audio_seq_len = max(au.shape[2] for au in audio) * audio[0].shape[3]
+
+ audio = [au.flatten(2).transpose(1, 2) for au in audio] # [1, t*32, 1536]
+ audio = torch.cat([
+ torch.cat([au, au.new_zeros(1, int(audio_seq_len) - au.size(1), au.size(2))], dim=1)
+ for au in audio
+ ])
+ else:
+ audio = None
+ audio_seq_len = None
+
+ # ---- tea_cache integration (mirrors your working model) ----
+ if tea_cache is not None:
+ # Use the pre-block tokens 'x' and time-mod 'e0' to decide whether to reuse cache
+ tea_cache_update = tea_cache.check(self, x, e0)
+ else:
+ tea_cache_update = False
+
+ ori_x_len = x.shape[1] # remember original token length before potential cache extension
+
+ if tea_cache_update:
+ # Let the cache inject/append any needed past states/tokens for reuse
+ x = tea_cache.update(x)
+ else:
+ # arguments for blocks
+ kwargs = dict(
+ e=e0,
+ seq_lens=seq_lens,
+ grid_sizes=grid_sizes,
+ freqs=freqs,
+ context=context,
+ context_lens=context_lens,
+ audio=audio,
+ audio_seq_len=audio_seq_len
+ )
+
+ # transformer blocks
+ for block in self.blocks:
+ x = block(x, **kwargs)
+
+ if tea_cache is not None:
+ x_cache = x[:, :ori_x_len]
+ tea_cache.store(x_cache)
+
+ # head
+ x = self.head(x, e)
+
+ # unpatchify
+ x = self.unpatchify(x, grid_sizes)
+ return [u.float() for u in x]
+
+
+ def unpatchify(self, x, grid_sizes):
+ r"""
+ Reconstruct video tensors from patch embeddings.
+
+ Args:
+ x (List[Tensor]):
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
+ grid_sizes (Tensor):
+ Original spatial-temporal grid dimensions before patching,
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
+
+ Returns:
+ List[Tensor]:
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
+ """
+
+ c = self.out_dim
+ out = []
+ for u, v in zip(x, grid_sizes.tolist()):
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
+ out.append(u)
+ return out
+
+ def init_weights(self):
+ r"""
+ Initialize model parameters using Xavier initialization.
+ """
+
+ # basic init
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ # init embeddings
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
+ for m in self.text_embedding.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, std=.02)
+ for m in self.time_embedding.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, std=.02)
+
+ # init output layer
+ nn.init.zeros_(self.head.head.weight)
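+
+# Minimal usage sketch (hypothetical tiny config, kept as a comment so importing this
+# module stays side-effect free; the shape comments above correspond to the
+# dim=5120 / 40-head configuration):
+#   model = WanModel(model_type='t2v', dim=128, ffn_dim=256, num_heads=8,
+#                    num_layers=2, insert_audio=False)
+#   x = [torch.randn(16, 2, 8, 8)]           # one latent clip [C_in, F, H, W]
+#   t = torch.tensor([500])                  # diffusion timestep per sample
+#   context = [torch.randn(77, 4096)]        # text embeddings [L, text_dim]
+#   out = model(x, t, context, seq_len=32)   # -> list of [C_out, F, H, W] latents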
diff --git a/humo/models/wan_modules/t5.py b/humo/models/wan_modules/t5.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6005c88e95208881064bc8b3bf6fe8df6c4126f
--- /dev/null
+++ b/humo/models/wan_modules/t5.py
@@ -0,0 +1,525 @@
+# Modified from transformers.models.t5.modeling_t5
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import logging
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .tokenizers import HuggingfaceTokenizer
+
+__all__ = [
+ 'T5Model',
+ 'T5Encoder',
+ 'T5Decoder',
+ 'T5EncoderModel',
+]
+
+
+def fp16_clamp(x):
+ if x.dtype == torch.float16 and torch.isinf(x).any():
+ clamp = torch.finfo(x.dtype).max - 1000
+ x = torch.clamp(x, min=-clamp, max=clamp)
+ return x
+
+
+def init_weights(m):
+ if isinstance(m, T5LayerNorm):
+ nn.init.ones_(m.weight)
+ elif isinstance(m, T5Model):
+ nn.init.normal_(m.token_embedding.weight, std=1.0)
+ elif isinstance(m, T5FeedForward):
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
+ elif isinstance(m, T5Attention):
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
+ elif isinstance(m, T5RelativeEmbedding):
+ nn.init.normal_(
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
+
+
+class GELU(nn.Module):
+
+ def forward(self, x):
+ return 0.5 * x * (1.0 + torch.tanh(
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+
+class T5LayerNorm(nn.Module):
+
+ def __init__(self, dim, eps=1e-6):
+ super(T5LayerNorm, self).__init__()
+ self.dim = dim
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
+ self.eps)
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ x = x.type_as(self.weight)
+ return self.weight * x
+
+
+class T5Attention(nn.Module):
+
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
+ assert dim_attn % num_heads == 0
+ super(T5Attention, self).__init__()
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.num_heads = num_heads
+ self.head_dim = dim_attn // num_heads
+
+ # layers
+ self.q = nn.Linear(dim, dim_attn, bias=False)
+ self.k = nn.Linear(dim, dim_attn, bias=False)
+ self.v = nn.Linear(dim, dim_attn, bias=False)
+ self.o = nn.Linear(dim_attn, dim, bias=False)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x, context=None, mask=None, pos_bias=None):
+ """
+ x: [B, L1, C].
+ context: [B, L2, C] or None.
+ mask: [B, L2] or [B, L1, L2] or None.
+ """
+ # check inputs
+ context = x if context is None else context
+ b, n, c = x.size(0), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.q(x).view(b, -1, n, c)
+ k = self.k(context).view(b, -1, n, c)
+ v = self.v(context).view(b, -1, n, c)
+
+ # attention bias
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
+ if pos_bias is not None:
+ attn_bias += pos_bias
+ if mask is not None:
+ assert mask.ndim in [2, 3]
+ mask = mask.view(b, 1, 1,
+ -1) if mask.ndim == 2 else mask.unsqueeze(1)
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
+
+ # compute attention (T5 does not use scaling)
+ attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
+ x = torch.einsum('bnij,bjnc->binc', attn, v)
+
+ # output
+ x = x.reshape(b, -1, n * c)
+ x = self.o(x)
+ x = self.dropout(x)
+ return x
+
+
+class T5FeedForward(nn.Module):
+
+ def __init__(self, dim, dim_ffn, dropout=0.1):
+ super(T5FeedForward, self).__init__()
+ self.dim = dim
+ self.dim_ffn = dim_ffn
+
+ # layers
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ x = self.fc1(x) * self.gate(x)
+ x = self.dropout(x)
+ x = self.fc2(x)
+ x = self.dropout(x)
+ return x
+
+
+class T5SelfAttention(nn.Module):
+
+ def __init__(self,
+ dim,
+ dim_attn,
+ dim_ffn,
+ num_heads,
+ num_buckets,
+ shared_pos=True,
+ dropout=0.1):
+ super(T5SelfAttention, self).__init__()
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.dim_ffn = dim_ffn
+ self.num_heads = num_heads
+ self.num_buckets = num_buckets
+ self.shared_pos = shared_pos
+
+ # layers
+ self.norm1 = T5LayerNorm(dim)
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
+ self.norm2 = T5LayerNorm(dim)
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
+ num_buckets, num_heads, bidirectional=True)
+
+ def forward(self, x, mask=None, pos_bias=None):
+ e = pos_bias if self.shared_pos else self.pos_embedding(
+ x.size(1), x.size(1))
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
+ return x
+
+
+class T5CrossAttention(nn.Module):
+
+ def __init__(self,
+ dim,
+ dim_attn,
+ dim_ffn,
+ num_heads,
+ num_buckets,
+ shared_pos=True,
+ dropout=0.1):
+ super(T5CrossAttention, self).__init__()
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.dim_ffn = dim_ffn
+ self.num_heads = num_heads
+ self.num_buckets = num_buckets
+ self.shared_pos = shared_pos
+
+ # layers
+ self.norm1 = T5LayerNorm(dim)
+ self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
+ self.norm2 = T5LayerNorm(dim)
+ self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
+ self.norm3 = T5LayerNorm(dim)
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
+ num_buckets, num_heads, bidirectional=False)
+
+ def forward(self,
+ x,
+ mask=None,
+ encoder_states=None,
+ encoder_mask=None,
+ pos_bias=None):
+ e = pos_bias if self.shared_pos else self.pos_embedding(
+ x.size(1), x.size(1))
+ x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
+ x = fp16_clamp(x + self.cross_attn(
+ self.norm2(x), context=encoder_states, mask=encoder_mask))
+ x = fp16_clamp(x + self.ffn(self.norm3(x)))
+ return x
+
+
+class T5RelativeEmbedding(nn.Module):
+
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
+ super(T5RelativeEmbedding, self).__init__()
+ self.num_buckets = num_buckets
+ self.num_heads = num_heads
+ self.bidirectional = bidirectional
+ self.max_dist = max_dist
+
+ # layers
+ self.embedding = nn.Embedding(num_buckets, num_heads)
+
+ def forward(self, lq, lk):
+ device = self.embedding.weight.device
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
+ # torch.arange(lq).unsqueeze(1).to(device)
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
+ torch.arange(lq, device=device).unsqueeze(1)
+ rel_pos = self._relative_position_bucket(rel_pos)
+ rel_pos_embeds = self.embedding(rel_pos)
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
+ 0) # [1, N, Lq, Lk]
+ return rel_pos_embeds.contiguous()
+
+ def _relative_position_bucket(self, rel_pos):
+ # preprocess
+ if self.bidirectional:
+ num_buckets = self.num_buckets // 2
+ rel_buckets = (rel_pos > 0).long() * num_buckets
+ rel_pos = torch.abs(rel_pos)
+ else:
+ num_buckets = self.num_buckets
+ rel_buckets = 0
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
+
+ # embeddings for small and large positions
+ max_exact = num_buckets // 2
+ rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
+ math.log(self.max_dist / max_exact) *
+ (num_buckets - max_exact)).long()
+ rel_pos_large = torch.min(
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
+ return rel_buckets
+
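+# Bucket illustration for _relative_position_bucket (umt5 defaults: num_buckets=32,
+# bidirectional): each direction gets 16 buckets, relative distances 0-7 keep their own
+# bucket, distances 8..127 are log-spaced into the remaining 8, and anything at or
+# beyond max_dist=128 lands in the last bucket.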
+
+class T5Encoder(nn.Module):
+
+ def __init__(self,
+ vocab,
+ dim,
+ dim_attn,
+ dim_ffn,
+ num_heads,
+ num_layers,
+ num_buckets,
+ shared_pos=True,
+ dropout=0.1):
+ super(T5Encoder, self).__init__()
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.dim_ffn = dim_ffn
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.num_buckets = num_buckets
+ self.shared_pos = shared_pos
+
+ # layers
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
+ else nn.Embedding(vocab, dim)
+ self.pos_embedding = T5RelativeEmbedding(
+ num_buckets, num_heads, bidirectional=True) if shared_pos else None
+ self.dropout = nn.Dropout(dropout)
+ self.blocks = nn.ModuleList([
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
+ shared_pos, dropout) for _ in range(num_layers)
+ ])
+ self.norm = T5LayerNorm(dim)
+
+ # initialize weights
+ self.apply(init_weights)
+
+ def forward(self, ids, mask=None):
+ x = self.token_embedding(ids)
+ x = self.dropout(x)
+ e = self.pos_embedding(x.size(1),
+ x.size(1)) if self.shared_pos else None
+ for block in self.blocks:
+ x = block(x, mask, pos_bias=e)
+ x = self.norm(x)
+ x = self.dropout(x)
+ return x
+
+
+class T5Decoder(nn.Module):
+
+ def __init__(self,
+ vocab,
+ dim,
+ dim_attn,
+ dim_ffn,
+ num_heads,
+ num_layers,
+ num_buckets,
+ shared_pos=True,
+ dropout=0.1):
+ super(T5Decoder, self).__init__()
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.dim_ffn = dim_ffn
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.num_buckets = num_buckets
+ self.shared_pos = shared_pos
+
+ # layers
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
+ else nn.Embedding(vocab, dim)
+ self.pos_embedding = T5RelativeEmbedding(
+ num_buckets, num_heads, bidirectional=False) if shared_pos else None
+ self.dropout = nn.Dropout(dropout)
+ self.blocks = nn.ModuleList([
+ T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
+ shared_pos, dropout) for _ in range(num_layers)
+ ])
+ self.norm = T5LayerNorm(dim)
+
+ # initialize weights
+ self.apply(init_weights)
+
+ def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
+ b, s = ids.size()
+
+ # causal mask
+ if mask is None:
+ mask = torch.tril(torch.ones(1, s, s).to(ids.device))
+ elif mask.ndim == 2:
+ mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
+
+ # layers
+ x = self.token_embedding(ids)
+ x = self.dropout(x)
+ e = self.pos_embedding(x.size(1),
+ x.size(1)) if self.shared_pos else None
+ for block in self.blocks:
+ x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
+ x = self.norm(x)
+ x = self.dropout(x)
+ return x
+
+
+class T5Model(nn.Module):
+
+ def __init__(self,
+ vocab_size,
+ dim,
+ dim_attn,
+ dim_ffn,
+ num_heads,
+ encoder_layers,
+ decoder_layers,
+ num_buckets,
+ shared_pos=True,
+ dropout=0.1):
+ super(T5Model, self).__init__()
+ self.vocab_size = vocab_size
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.dim_ffn = dim_ffn
+ self.num_heads = num_heads
+ self.encoder_layers = encoder_layers
+ self.decoder_layers = decoder_layers
+ self.num_buckets = num_buckets
+
+ # layers
+ self.token_embedding = nn.Embedding(vocab_size, dim)
+ self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
+ num_heads, encoder_layers, num_buckets,
+ shared_pos, dropout)
+ self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
+ num_heads, decoder_layers, num_buckets,
+ shared_pos, dropout)
+ self.head = nn.Linear(dim, vocab_size, bias=False)
+
+ # initialize weights
+ self.apply(init_weights)
+
+ def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
+ x = self.encoder(encoder_ids, encoder_mask)
+ x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
+ x = self.head(x)
+ return x
+
+
+def _t5(name,
+ encoder_only=False,
+ decoder_only=False,
+ return_tokenizer=False,
+ tokenizer_kwargs={},
+ dtype=torch.float32,
+ device='cpu',
+ **kwargs):
+ # sanity check
+ assert not (encoder_only and decoder_only)
+
+ # params
+ if encoder_only:
+ model_cls = T5Encoder
+ kwargs['vocab'] = kwargs.pop('vocab_size')
+ kwargs['num_layers'] = kwargs.pop('encoder_layers')
+ _ = kwargs.pop('decoder_layers')
+ elif decoder_only:
+ model_cls = T5Decoder
+ kwargs['vocab'] = kwargs.pop('vocab_size')
+ kwargs['num_layers'] = kwargs.pop('decoder_layers')
+ _ = kwargs.pop('encoder_layers')
+ else:
+ model_cls = T5Model
+
+ # init model
+ with torch.device(device):
+ model = model_cls(**kwargs)
+
+ # set device
+ model = model.to(dtype=dtype, device=device)
+
+ # init tokenizer
+ if return_tokenizer:
+ from .tokenizers import HuggingfaceTokenizer
+ tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
+ return model, tokenizer
+ else:
+ return model
+
+
+def umt5_xxl(**kwargs):
+ cfg = dict(
+ vocab_size=256384,
+ dim=4096,
+ dim_attn=4096,
+ dim_ffn=10240,
+ num_heads=64,
+ encoder_layers=24,
+ decoder_layers=24,
+ num_buckets=32,
+ shared_pos=False,
+ dropout=0.1)
+ cfg.update(**kwargs)
+ return _t5('umt5-xxl', **cfg)
+
+
+class T5EncoderModel(nn.Module):
+
+ def __init__(
+ self,
+ text_len,
+ dtype=torch.bfloat16,
+ device=torch.cuda.current_device(),
+ checkpoint_path=None,
+ tokenizer_path=None,
+ shard_fn=None,
+ ):
+ super(T5EncoderModel, self).__init__()
+ self.text_len = text_len
+ self.dtype = dtype
+ self.device = device
+ self.checkpoint_path = checkpoint_path
+ self.tokenizer_path = tokenizer_path
+
+ with torch.device(device):
+ self.model = T5Encoder(
+ vocab=256384,
+ dim=4096,
+ dim_attn=4096,
+ dim_ffn=10240,
+ num_heads=64,
+ num_layers=24,
+ num_buckets=32,
+ shared_pos=False,
+ dropout=0.1
+ )
+ # set device
+ self.model = self.model.to(dtype=dtype, device=device).eval().requires_grad_(False)
+
+ logging.info(f'loading {checkpoint_path}')
+ if checkpoint_path is not None:
+ self.model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
+
+ if shard_fn is not None:
+ self.model = shard_fn(self.model, sync_module_states=False)
+ else:
+ self.model.to(self.device)
+ # init tokenizer
+ self.tokenizer = HuggingfaceTokenizer(
+ name=tokenizer_path, seq_len=text_len, clean='whitespace')
+
+ @torch.no_grad()
+ def __call__(self, texts, device):
+ ids, mask = self.tokenizer(
+ texts, return_mask=True, add_special_tokens=True)
+ ids = ids.to(device)
+ mask = mask.to(device)
+ seq_lens = mask.gt(0).sum(dim=1).long()
+ context = self.model(ids, mask)
+ return [u[:v] for u, v in zip(context, seq_lens)]
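+
+# Usage sketch (hypothetical local paths; bf16 weights assumed):
+#   enc = T5EncoderModel(text_len=512, dtype=torch.bfloat16, device='cuda',
+#                        checkpoint_path='ckpts/umt5-xxl-enc-bf16.pth',
+#                        tokenizer_path='ckpts/umt5-xxl')
+#   context = enc(["a cat playing piano"], device='cuda')  # list of [L_i, 4096] tensors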
diff --git a/humo/models/wan_modules/tokenizers.py b/humo/models/wan_modules/tokenizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..329c2add418c49c5df1c589c45dd124a59caafc7
--- /dev/null
+++ b/humo/models/wan_modules/tokenizers.py
@@ -0,0 +1,82 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import html
+import string
+
+import ftfy
+import regex as re
+from transformers import AutoTokenizer
+
+__all__ = ['HuggingfaceTokenizer']
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+def canonicalize(text, keep_punctuation_exact_string=None):
+ text = text.replace('_', ' ')
+ if keep_punctuation_exact_string:
+ text = keep_punctuation_exact_string.join(
+ part.translate(str.maketrans('', '', string.punctuation))
+ for part in text.split(keep_punctuation_exact_string))
+ else:
+ text = text.translate(str.maketrans('', '', string.punctuation))
+ text = text.lower()
+ text = re.sub(r'\s+', ' ', text)
+ return text.strip()
+
+
+class HuggingfaceTokenizer:
+
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
+ assert clean in (None, 'whitespace', 'lower', 'canonicalize')
+ self.name = name
+ self.seq_len = seq_len
+ self.clean = clean
+
+ # init tokenizer
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
+ self.vocab_size = self.tokenizer.vocab_size
+
+ def __call__(self, sequence, **kwargs):
+ return_mask = kwargs.pop('return_mask', False)
+
+ # arguments
+ _kwargs = {'return_tensors': 'pt'}
+ if self.seq_len is not None:
+ _kwargs.update({
+ 'padding': 'max_length',
+ 'truncation': True,
+ 'max_length': self.seq_len
+ })
+ _kwargs.update(**kwargs)
+
+ # tokenization
+ if isinstance(sequence, str):
+ sequence = [sequence]
+ if self.clean:
+ sequence = [self._clean(u) for u in sequence]
+ ids = self.tokenizer(sequence, **_kwargs)
+
+ # output
+ if return_mask:
+ return ids.input_ids, ids.attention_mask
+ else:
+ return ids.input_ids
+
+ def _clean(self, text):
+ if self.clean == 'whitespace':
+ text = whitespace_clean(basic_clean(text))
+ elif self.clean == 'lower':
+ text = whitespace_clean(basic_clean(text)).lower()
+ elif self.clean == 'canonicalize':
+ text = canonicalize(basic_clean(text))
+ return text
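+
+# Usage sketch (assumes the umt5-xxl tokenizer files are available locally or on the Hub):
+#   tok = HuggingfaceTokenizer('google/umt5-xxl', seq_len=512, clean='whitespace')
+#   ids, mask = tok(['a cat playing piano'], return_mask=True, add_special_tokens=True)
+#   # ids and mask have shape [1, 512], padded/truncated to seq_len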
diff --git a/humo/models/wan_modules/vae.py b/humo/models/wan_modules/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..33b0ba82f5ce451fbecd8a0deaae19ffa8ca4b38
--- /dev/null
+++ b/humo/models/wan_modules/vae.py
@@ -0,0 +1,666 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import logging
+
+import torch
+import torch.cuda.amp as amp
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+__all__ = [
+ 'WanVAE',
+]
+
+CACHE_T = 2
+
+
+class CausalConv3d(nn.Conv3d):
+ """
+ Causal 3D convolution.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._padding = (self.padding[2], self.padding[2], self.padding[1],
+ self.padding[1], 2 * self.padding[0], 0)
+ self.padding = (0, 0, 0)
+
+ def forward(self, x, cache_x=None):
+ padding = list(self._padding)
+ if cache_x is not None and self._padding[4] > 0:
+ cache_x = cache_x.to(x.device)
+ x = torch.cat([cache_x, x], dim=2)
+ padding[4] -= cache_x.shape[2]
+ x = F.pad(x, padding)
+
+ return super().forward(x)
+
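+# Example of the causal padding above: CausalConv3d(dim, dim, 3, padding=1) keeps the
+# usual spatial padding of 1 but moves both temporal padding frames to the front, so
+# frame t never sees frames after t; when cache_x (the last CACHE_T=2 frames of the
+# previous chunk) is passed in, it replaces that front padding for chunked decoding.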
+
+class RMS_norm(nn.Module):
+
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
+ super().__init__()
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+ self.channel_first = channel_first
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(shape))
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
+
+ def forward(self, x):
+ return F.normalize(
+ x, dim=(1 if self.channel_first else
+ -1)) * self.scale * self.gamma + self.bias
+
+
+class Upsample(nn.Upsample):
+
+ def forward(self, x):
+ """
+ Fix bfloat16 support for nearest neighbor interpolation.
+ """
+ return super().forward(x.float()).type_as(x)
+
+
+class Resample(nn.Module):
+
+ def __init__(self, dim, mode):
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
+ 'downsample3d')
+ super().__init__()
+ self.dim = dim
+ self.mode = mode
+
+ # layers
+ if mode == 'upsample2d':
+ self.resample = nn.Sequential(
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
+ elif mode == 'upsample3d':
+ self.resample = nn.Sequential(
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
+ self.time_conv = CausalConv3d(
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+
+ elif mode == 'downsample2d':
+ self.resample = nn.Sequential(
+ nn.ZeroPad2d((0, 1, 0, 1)),
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+ elif mode == 'downsample3d':
+ self.resample = nn.Sequential(
+ nn.ZeroPad2d((0, 1, 0, 1)),
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+ self.time_conv = CausalConv3d(
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+
+ else:
+ self.resample = nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ b, c, t, h, w = x.size()
+ if self.mode == 'upsample3d':
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = 'Rep'
+ feat_idx[0] += 1
+ else:
+
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[
+ idx] is not None and feat_cache[idx] != 'Rep':
+ # cache last frame of last two chunk
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ if cache_x.shape[2] < 2 and feat_cache[
+ idx] is not None and feat_cache[idx] == 'Rep':
+ cache_x = torch.cat([
+ torch.zeros_like(cache_x).to(cache_x.device),
+ cache_x
+ ],
+ dim=2)
+ if feat_cache[idx] == 'Rep':
+ x = self.time_conv(x)
+ else:
+ x = self.time_conv(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+
+ x = x.reshape(b, 2, c, t, h, w)
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
+ 3)
+ x = x.reshape(b, c, t * 2, h, w)
+ t = x.shape[2]
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
+ x = self.resample(x)
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
+
+ if self.mode == 'downsample3d':
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = x.clone()
+ feat_idx[0] += 1
+ else:
+
+ cache_x = x[:, :, -1:, :, :].clone()
+ # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
+ # # cache last frame of last two chunk
+ # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+ x = self.time_conv(
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ return x
+
+ def init_weight(self, conv):
+ conv_weight = conv.weight
+ nn.init.zeros_(conv_weight)
+ c1, c2, t, h, w = conv_weight.size()
+ one_matrix = torch.eye(c1, c2)
+ init_matrix = one_matrix
+ nn.init.zeros_(conv_weight)
+ #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
+ conv.weight.data.copy_(conv_weight)
+ nn.init.zeros_(conv.bias.data)
+
+ def init_weight2(self, conv):
+ conv_weight = conv.weight.data
+ nn.init.zeros_(conv_weight)
+ c1, c2, t, h, w = conv_weight.size()
+ init_matrix = torch.eye(c1 // 2, c2)
+ #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
+ conv.weight.data.copy_(conv_weight)
+ nn.init.zeros_(conv.bias.data)
+
+
+class ResidualBlock(nn.Module):
+
+ def __init__(self, in_dim, out_dim, dropout=0.0):
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+
+ # layers
+ self.residual = nn.Sequential(
+ RMS_norm(in_dim, images=False), nn.SiLU(),
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
+ if in_dim != out_dim else nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ h = self.shortcut(x)
+ for layer in self.residual:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = layer(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ Causal self-attention with a single head.
+ """
+
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ # layers
+ self.norm = RMS_norm(dim)
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
+ self.proj = nn.Conv2d(dim, dim, 1)
+
+ # zero out the last layer params
+ nn.init.zeros_(self.proj.weight)
+
+ def forward(self, x):
+ identity = x
+ b, c, t, h, w = x.size()
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
+ x = self.norm(x)
+ # compute query, key, value
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
+ -1).permute(0, 1, 3,
+ 2).contiguous().chunk(
+ 3, dim=-1)
+
+ # apply attention
+ x = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ )
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
+
+ # output
+ x = self.proj(x)
+ x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
+ return x + identity
+
+
+class Encoder3d(nn.Module):
+
+ def __init__(self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[True, True, False],
+ dropout=0.0):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_downsample = temperal_downsample
+
+ # dimensions
+ dims = [dim * u for u in [1] + dim_mult]
+ scale = 1.0
+
+ # init block
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
+
+ # downsample blocks
+ downsamples = []
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ for _ in range(num_res_blocks):
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
+ if scale in attn_scales:
+ downsamples.append(AttentionBlock(out_dim))
+ in_dim = out_dim
+
+ # downsample block
+ if i != len(dim_mult) - 1:
+ mode = 'downsample3d' if temperal_downsample[
+ i] else 'downsample2d'
+ downsamples.append(Resample(out_dim, mode=mode))
+ scale /= 2.0
+ self.downsamples = nn.Sequential(*downsamples)
+
+ # middle blocks
+ self.middle = nn.Sequential(
+ ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
+ ResidualBlock(out_dim, out_dim, dropout))
+
+ # output blocks
+ self.head = nn.Sequential(
+ RMS_norm(out_dim, images=False), nn.SiLU(),
+ CausalConv3d(out_dim, z_dim, 3, padding=1))
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = self.conv1(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ ## downsamples
+ for layer in self.downsamples:
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## middle
+ for layer in self.middle:
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## head
+ for layer in self.head:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                    # if the cache holds a single frame, prepend the last frame of the previous chunk
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = layer(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x
+
+
+class Decoder3d(nn.Module):
+
+ def __init__(self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_upsample=[False, True, True],
+ dropout=0.0):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_upsample = temperal_upsample
+
+ # dimensions
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
+ scale = 1.0 / 2**(len(dim_mult) - 2)
+
+ # init block
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
+
+ # middle blocks
+ self.middle = nn.Sequential(
+ ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
+ ResidualBlock(dims[0], dims[0], dropout))
+
+ # upsample blocks
+ upsamples = []
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ if i == 1 or i == 2 or i == 3:
+ in_dim = in_dim // 2
+ for _ in range(num_res_blocks + 1):
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
+ if scale in attn_scales:
+ upsamples.append(AttentionBlock(out_dim))
+ in_dim = out_dim
+
+ # upsample block
+ if i != len(dim_mult) - 1:
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
+ upsamples.append(Resample(out_dim, mode=mode))
+ scale *= 2.0
+ self.upsamples = nn.Sequential(*upsamples)
+
+ # output blocks
+ self.head = nn.Sequential(
+ RMS_norm(out_dim, images=False), nn.SiLU(),
+ CausalConv3d(out_dim, 3, 3, padding=1))
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ ## conv1
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                # if the cache holds a single frame, prepend the last frame of the previous chunk
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = self.conv1(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ ## middle
+ for layer in self.middle:
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## upsamples
+ for layer in self.upsamples:
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## head
+ for layer in self.head:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                    # if the cache holds a single frame, prepend the last frame of the previous chunk
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = layer(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x
+
+
+def count_conv3d(model):
+ count = 0
+ for m in model.modules():
+ if isinstance(m, CausalConv3d):
+ count += 1
+ return count
+
+
+class WanVAE_(nn.Module):
+
+ def __init__(self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[True, True, False],
+ dropout=0.0):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_downsample = temperal_downsample
+ self.temperal_upsample = temperal_downsample[::-1]
+
+ # modules
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
+ attn_scales, self.temperal_downsample, dropout)
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
+ attn_scales, self.temperal_upsample, dropout)
+
+ def forward(self, x):
+ mu, log_var = self.encode(x)
+ z = self.reparameterize(mu, log_var)
+ x_recon = self.decode(z)
+ return x_recon, mu, log_var
+
+ def encode(self, x, scale):
+ self.clear_cache()
+ ## cache
+ t = x.shape[2]
+ iter_ = 1 + (t - 1) // 4
+        ## split the input x along the time axis into chunks of 1, 4, 4, 4, ... frames
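+        ## e.g. a 97-frame clip is split into one 1-frame chunk followed by 24 chunks of 4 frames (iter_ = 25)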
+ for i in range(iter_):
+ self._enc_conv_idx = [0]
+ if i == 0:
+ out = self.encoder(
+ x[:, :, :1, :, :],
+ feat_cache=self._enc_feat_map,
+ feat_idx=self._enc_conv_idx)
+ else:
+ out_ = self.encoder(
+ x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
+ feat_cache=self._enc_feat_map,
+ feat_idx=self._enc_conv_idx)
+ out = torch.cat([out, out_], 2)
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
+ if isinstance(scale[0], torch.Tensor):
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
+ 1, self.z_dim, 1, 1, 1)
+ else:
+ mu = (mu - scale[0]) * scale[1]
+ self.clear_cache()
+ return mu
+
+ def decode(self, z, scale):
+ self.clear_cache()
+ # z: [b,c,t,h,w]
+ if isinstance(scale[0], torch.Tensor):
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
+ 1, self.z_dim, 1, 1, 1)
+ else:
+ z = z / scale[1] + scale[0]
+ iter_ = z.shape[2]
+ x = self.conv2(z)
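+        # decode one latent frame per iteration; the causal conv feature caches carry temporal context between iterations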
+ for i in range(iter_):
+ self._conv_idx = [0]
+ if i == 0:
+ out = self.decoder(
+ x[:, :, i:i + 1, :, :],
+ feat_cache=self._feat_map,
+ feat_idx=self._conv_idx)
+ else:
+ out_ = self.decoder(
+ x[:, :, i:i + 1, :, :],
+ feat_cache=self._feat_map,
+ feat_idx=self._conv_idx)
+ out = torch.cat([out, out_], 2)
+ self.clear_cache()
+ return out
+
+ def reparameterize(self, mu, log_var):
+ std = torch.exp(0.5 * log_var)
+ eps = torch.randn_like(std)
+ return eps * std + mu
+
+ def sample(self, imgs, deterministic=False):
+ mu, log_var = self.encode(imgs)
+ if deterministic:
+ return mu
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
+ return mu + std * torch.randn_like(std)
+
+ def clear_cache(self):
+ self._conv_num = count_conv3d(self.decoder)
+ self._conv_idx = [0]
+ self._feat_map = [None] * self._conv_num
+ #cache encode
+ self._enc_conv_num = count_conv3d(self.encoder)
+ self._enc_conv_idx = [0]
+ self._enc_feat_map = [None] * self._enc_conv_num
+
+
+def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
+ """
+ Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
+ """
+ # params
+ cfg = dict(
+ dim=96,
+ z_dim=z_dim,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[False, True, True],
+ dropout=0.0)
+ cfg.update(**kwargs)
+
+ # init model
+ # with torch.device('meta'):
+ model = WanVAE_(**cfg)
+
+ # load checkpoint
+ logging.info(f'loading {pretrained_path}')
+ if pretrained_path is not None:
+ model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True)
+
+ return model
+
+
+class WanVAE:
+
+ def __init__(self,
+ z_dim=16,
+ vae_pth=None,
+ dtype=torch.float,
+ device="cuda"):
+ self.dtype = dtype
+ self.device = device
+
+ mean = [
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
+ ]
+ std = [
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
+ ]
+ self.mean = torch.tensor(mean, dtype=dtype, device=device)
+ self.std = torch.tensor(std, dtype=dtype, device=device)
+ self.scale = [self.mean, 1.0 / self.std]
+
+ # init model
+ self.model = _video_vae(
+ pretrained_path=vae_pth,
+ z_dim=z_dim,
+ ).eval().requires_grad_(False).to(device)
+
+ @torch.no_grad()
+ def encode(self, videos, device):
+ """
+ videos: A list of videos each with shape [C, T, H, W].
+ """
+
+ with torch.amp.autocast('cuda', dtype=self.dtype):
+ return [
+                self.model.encode(u.unsqueeze(0).to(device, self.dtype), self.scale).float().squeeze(0)
+ for u in videos
+ ]
+
+ @torch.no_grad()
+ def decode(self, zs):
+ with torch.amp.autocast('cuda', dtype=self.dtype):
+ return [
+ self.model.decode(u.unsqueeze(0),
+ self.scale).float().clamp_(-1, 1).squeeze(0)
+ for u in zs
+ ]
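+
+
+if __name__ == "__main__":
+    # Minimal usage sketch: the checkpoint path is a placeholder and the input is random
+    # data normalized to [-1, 1] on a CUDA device.
+    vae = WanVAE(z_dim=16, vae_pth="path/to/Wan2.1_VAE.pth", dtype=torch.bfloat16, device="cuda")
+    video = torch.rand(3, 97, 480, 832, device="cuda") * 2 - 1   # [C, T, H, W]
+    z = vae.encode([video], device="cuda")[0]                    # latent, [16, 25, 60, 104]
+    recon = vae.decode([z])[0]                                   # reconstruction, [3, 97, 480, 832]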
diff --git a/humo/models/wan_modules/xlm_roberta.py b/humo/models/wan_modules/xlm_roberta.py
new file mode 100644
index 0000000000000000000000000000000000000000..34858de961e1033ad120c13b1a0342f84f4907f6
--- /dev/null
+++ b/humo/models/wan_modules/xlm_roberta.py
@@ -0,0 +1,170 @@
+# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ['XLMRoberta', 'xlm_roberta_large']
+
+
+class SelfAttention(nn.Module):
+
+ def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.eps = eps
+
+ # layers
+ self.q = nn.Linear(dim, dim)
+ self.k = nn.Linear(dim, dim)
+ self.v = nn.Linear(dim, dim)
+ self.o = nn.Linear(dim, dim)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x, mask):
+ """
+ x: [B, L, C].
+ """
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
+ k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
+ v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
+
+ # compute attention
+ p = self.dropout.p if self.training else 0.0
+ x = F.scaled_dot_product_attention(q, k, v, mask, p)
+ x = x.permute(0, 2, 1, 3).reshape(b, s, c)
+
+ # output
+ x = self.o(x)
+ x = self.dropout(x)
+ return x
+
+
+class AttentionBlock(nn.Module):
+
+ def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.post_norm = post_norm
+ self.eps = eps
+
+ # layers
+ self.attn = SelfAttention(dim, num_heads, dropout, eps)
+ self.norm1 = nn.LayerNorm(dim, eps=eps)
+ self.ffn = nn.Sequential(
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
+ nn.Dropout(dropout))
+ self.norm2 = nn.LayerNorm(dim, eps=eps)
+
+ def forward(self, x, mask):
+ if self.post_norm:
+ x = self.norm1(x + self.attn(x, mask))
+ x = self.norm2(x + self.ffn(x))
+ else:
+ x = x + self.attn(self.norm1(x), mask)
+ x = x + self.ffn(self.norm2(x))
+ return x
+
+
+class XLMRoberta(nn.Module):
+ """
+ XLMRobertaModel with no pooler and no LM head.
+ """
+
+ def __init__(self,
+ vocab_size=250002,
+ max_seq_len=514,
+ type_size=1,
+ pad_id=1,
+ dim=1024,
+ num_heads=16,
+ num_layers=24,
+ post_norm=True,
+ dropout=0.1,
+ eps=1e-5):
+ super().__init__()
+ self.vocab_size = vocab_size
+ self.max_seq_len = max_seq_len
+ self.type_size = type_size
+ self.pad_id = pad_id
+ self.dim = dim
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.post_norm = post_norm
+ self.eps = eps
+
+ # embeddings
+ self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
+ self.type_embedding = nn.Embedding(type_size, dim)
+ self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
+ self.dropout = nn.Dropout(dropout)
+
+ # blocks
+ self.blocks = nn.ModuleList([
+ AttentionBlock(dim, num_heads, post_norm, dropout, eps)
+ for _ in range(num_layers)
+ ])
+
+ # norm layer
+ self.norm = nn.LayerNorm(dim, eps=eps)
+
+ def forward(self, ids):
+ """
+ ids: [B, L] of torch.LongTensor.
+ """
+ b, s = ids.shape
+ mask = ids.ne(self.pad_id).long()
+
+ # embeddings
+ x = self.token_embedding(ids) + \
+ self.type_embedding(torch.zeros_like(ids)) + \
+ self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
+ if self.post_norm:
+ x = self.norm(x)
+ x = self.dropout(x)
+
+ # blocks
+ mask = torch.where(
+ mask.view(b, 1, 1, s).gt(0), 0.0,
+ torch.finfo(x.dtype).min)
+ for block in self.blocks:
+ x = block(x, mask)
+
+ # output
+ if not self.post_norm:
+ x = self.norm(x)
+ return x
+
+
+def xlm_roberta_large(pretrained=False,
+ return_tokenizer=False,
+ device='cpu',
+ **kwargs):
+ """
+ XLMRobertaLarge adapted from Huggingface.
+ """
+ # params
+ cfg = dict(
+ vocab_size=250002,
+ max_seq_len=514,
+ type_size=1,
+ pad_id=1,
+ dim=1024,
+ num_heads=16,
+ num_layers=24,
+ post_norm=True,
+ dropout=0.1,
+ eps=1e-5)
+ cfg.update(**kwargs)
+
+ # init a model on device
+ with torch.device(device):
+ model = XLMRoberta(**cfg)
+ return model
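+
+
+if __name__ == "__main__":
+    # Minimal usage sketch with randomly initialised weights; this builder does not load
+    # pretrained weights, so the forward pass below is only a shape check.
+    model = xlm_roberta_large(device='cpu').eval()
+    ids = torch.full((2, 16), 5, dtype=torch.long)    # dummy token ids (pad_id is 1)
+    ids[:, -4:] = 1                                   # trailing padding
+    feats = model(ids)                                # [2, 16, 1024] token features
+    print(feats.shape)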
diff --git a/humo/utils/audio_processor_whisper.py b/humo/utils/audio_processor_whisper.py
new file mode 100644
index 0000000000000000000000000000000000000000..557bf49c3abaa477405dbd6313c8b884a64ec4d0
--- /dev/null
+++ b/humo/utils/audio_processor_whisper.py
@@ -0,0 +1,173 @@
+# pylint: disable=C0301
+'''
+This module contains the AudioProcessor class and related functions for processing audio data.
+It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction,
+and audio separation. The class is initialized with configuration parameters and can process
+audio files using the provided models.
+'''
+import os
+import subprocess
+
+import librosa
+import numpy as np
+import torch
+from audio_separator.separator import Separator
+from transformers import WhisperModel, AutoFeatureExtractor
+import torch.nn.functional as F
+
+
+def linear_interpolation_fps(features, input_fps, output_fps, output_len=None):
+ features = features.transpose(1, 2) # [1, C, T]
+ seq_len = features.shape[2] / float(input_fps)
+ if output_len is None:
+ output_len = int(seq_len * output_fps)
+ output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear')
+ return output_features.transpose(1, 2)
+
+
+def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int):
+ p = subprocess.Popen([
+ "ffmpeg", "-y", "-v", "error", "-i", input_audio_file, "-ar", str(sample_rate), output_audio_file
+ ])
+ ret = p.wait()
+ assert ret == 0, "Resample audio failed!"
+ return output_audio_file
+
+class AudioProcessor:
+ """
+ AudioProcessor is a class that handles the processing of audio files.
+ It takes care of preprocessing the audio files, extracting features
+ using wav2vec models, and separating audio signals if needed.
+
+ :param sample_rate: Sampling rate of the audio file
+ :param fps: Frames per second for the extracted features
+ :param wav2vec_model_path: Path to the wav2vec model
+ :param only_last_features: Whether to only use the last features
+ :param audio_separator_model_path: Path to the audio separator model
+ :param audio_separator_model_name: Name of the audio separator model
+ :param cache_dir: Directory to cache the intermediate results
+ :param device: Device to run the processing on
+ """
+ def __init__(
+ self,
+ sample_rate,
+ fps,
+ wav2vec_model_path,
+ wav2vec_feature_type,
+ audio_separator_model_path:str=None,
+ audio_separator_model_name:str=None,
+ cache_dir:str='',
+ device="cuda:0",
+ ) -> None:
+ self.sample_rate = sample_rate
+ self.fps = fps
+ self.device = device
+
+ self.whisper = WhisperModel.from_pretrained(wav2vec_model_path).to(device).eval()
+ self.whisper.requires_grad_(False)
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(wav2vec_model_path)
+
+ if audio_separator_model_name is not None:
+ try:
+ os.makedirs(cache_dir, exist_ok=True)
+ except OSError as _:
+                print("Failed to create the output cache dir.")
+ self.audio_separator = Separator(
+ output_dir=cache_dir,
+ output_single_stem="vocals",
+ model_file_dir=audio_separator_model_path,
+ )
+ self.audio_separator.load_model(audio_separator_model_name)
+            assert self.audio_separator.model_instance is not None, "Failed to load the audio separator model."
+ else:
+            self.audio_separator = None
+            print("Using the audio directly without a vocal separator.")
+
+
+ def get_audio_feature(self, audio_path):
+ audio_input, sampling_rate = librosa.load(audio_path, sr=16000)
+ assert sampling_rate == 16000
+
+ audio_features = []
+ window = 750*640
+ for i in range(0, len(audio_input), window):
+ audio_feature = self.feature_extractor(audio_input[i:i+window],
+ sampling_rate=sampling_rate,
+ return_tensors="pt",
+ ).input_features
+ audio_features.append(audio_feature)
+ audio_features = torch.cat(audio_features, dim=-1)
+ return audio_features, len(audio_input) // 640
+
+
+ def preprocess(self, audio_path: str):
+ audio_input, audio_len = self.get_audio_feature(audio_path)
+ audio_feature = audio_input.to(self.whisper.device).float()
+ window = 3000
+ audio_prompts = []
+ for i in range(0, audio_feature.shape[-1], window):
+ audio_prompt = self.whisper.encoder(audio_feature[:,:,i:i+window], output_hidden_states=True).hidden_states
+ audio_prompt = torch.stack(audio_prompt, dim=2)
+ audio_prompts.append(audio_prompt)
+
+ audio_prompts = torch.cat(audio_prompts, dim=1)
+ audio_prompts = audio_prompts[:,:audio_len*2]
+
+ audio_emb = self.audio_emb_enc(audio_prompts, wav_enc_type="whisper")
+
+ return audio_emb, audio_emb.shape[0]
+
+ def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"):
+ if wav_enc_type == "wav2vec":
+ feat_merge = audio_emb
+ elif wav_enc_type == "whisper":
+ # [1, T, 33, 1280]
+ feat0 = linear_interpolation_fps(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25)
+ feat1 = linear_interpolation_fps(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25)
+ feat2 = linear_interpolation_fps(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25)
+ feat3 = linear_interpolation_fps(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25)
+ feat4 = linear_interpolation_fps(audio_emb[:, :, 32], 50, 25)
+ feat_merge = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0] # [T, 5, 1280]
+ else:
+ raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}")
+
+ return feat_merge
+
+ def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2):
+ zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
+ zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) # device=audio_emb.device
+ iter_ = 1 + (frame_num - 1) // 4
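+        # one 8-step audio window per VAE latent frame (the first latent covers 1 video frame, later latents cover 4)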
+ audio_emb_wind = []
+ for lt_i in range(iter_):
+ if lt_i == 0: # latent_i
+                # first VAE latent frame: left-pad the audio window with zeros to mark the clip start
+ st = frame0_idx + lt_i - 2
+ ed = frame0_idx + lt_i + 3
+ wind_feat = torch.stack([
+ audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
+ for i in range(st, ed)
+ ], dim=0) # [5, 13, 768]
+ wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0) # [8, 13, 768]
+ else:
+ st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
+ ed = frame0_idx + 1 + 4 * lt_i + audio_shift
+ wind_feat = torch.stack([
+ audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
+ for i in range(st, ed)
+ ], dim=0) # [8, 13, 768]
+ audio_emb_wind.append(wind_feat)
+ audio_emb_wind = torch.stack(audio_emb_wind, dim=0) # [iter_, 8, 13, 768]
+
+ return audio_emb_wind, ed - audio_shift
+
+ def close(self):
+ """
+ TODO: to be implemented
+ """
+ return self
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, _exc_type, _exc_val, _exc_tb):
+ self.close()
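+
+
+if __name__ == "__main__":
+    # Minimal usage sketch: the Whisper checkpoint path and audio file are placeholders and a
+    # CUDA device (the class default) is assumed. No separator model is given, so vocal
+    # separation is skipped.
+    processor = AudioProcessor(
+        sample_rate=16000,
+        fps=25,
+        wav2vec_model_path="path/to/whisper-large-v3",
+        wav2vec_feature_type="whisper",
+    )
+    audio_emb, num_steps = processor.preprocess("path/to/speech.wav")          # [T, 5, 1280]
+    windows, _ = processor.get_audio_emb_window(audio_emb, frame_num=97, frame0_idx=0)
+    print(windows.shape)                                                       # [25, 8, 5, 1280]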
diff --git a/humo/utils/wav2vec.py b/humo/utils/wav2vec.py
new file mode 100644
index 0000000000000000000000000000000000000000..7088391abfa1e5684bd36330b79ff6a4570c733b
--- /dev/null
+++ b/humo/utils/wav2vec.py
@@ -0,0 +1,218 @@
+# pylint: disable=R0901
+# src/models/wav2vec.py
+
+"""
+This module defines the Wav2Vec model, which is a pre-trained model for speech recognition and understanding.
+It inherits from the Wav2Vec2Model class in the transformers library and provides additional functionalities
+such as feature extraction and encoding.
+
+Classes:
+ Wav2VecModel: Inherits from Wav2Vec2Model and adds additional methods for feature extraction and encoding.
+
+Functions:
+ linear_interpolation: Interpolates the features based on the sequence length.
+"""
+
+import torch.nn.functional as F
+from transformers import Wav2Vec2Model
+from transformers.modeling_outputs import BaseModelOutput
+
+
+class Wav2VecModel(Wav2Vec2Model):
+ """
+ Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library.
+ It inherits all the functionality of the Wav2Vec2Model and adds additional methods for feature extraction and encoding.
+ ...
+
+ Attributes:
+ base_model (Wav2Vec2Model): The base Wav2Vec2Model object.
+
+ Methods:
+        forward(input_values, seq_len, attention_mask=None, mask_time_indices=None,
+                output_attentions=None, output_hidden_states=None, return_dict=None):
+ Forward pass of the Wav2VecModel.
+ It takes input_values, seq_len, and other optional parameters as input and returns the output of the base model.
+
+ feature_extract(input_values, seq_len):
+ Extracts features from the input_values using the base model.
+
+ encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None):
+ Encodes the extracted features using the base model and returns the encoded features.
+ """
+ def forward(
+ self,
+ input_values,
+ seq_len,
+ attention_mask=None,
+ mask_time_indices=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ """
+ Forward pass of the Wav2Vec model.
+
+ Args:
+ self: The instance of the model.
+ input_values: The input values (waveform) to the model.
+ seq_len: The sequence length of the input values.
+ attention_mask: Attention mask to be used for the model.
+ mask_time_indices: Mask indices to be used for the model.
+ output_attentions: If set to True, returns attentions.
+ output_hidden_states: If set to True, returns hidden states.
+ return_dict: If set to True, returns a BaseModelOutput instead of a tuple.
+
+ Returns:
+ The output of the Wav2Vec model.
+ """
+ self.config.output_attentions = True
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ extract_features = self.feature_extractor(input_values)
+ extract_features = extract_features.transpose(1, 2)
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
+
+ if attention_mask is not None:
+ # compute reduced attention_mask corresponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(
+ extract_features.shape[1], attention_mask, add_adapter=False
+ )
+
+ hidden_states, extract_features = self.feature_projection(extract_features)
+ hidden_states = self._mask_hidden_states(
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+ )
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if self.adapter is not None:
+ hidden_states = self.adapter(hidden_states)
+
+ if not return_dict:
+ return (hidden_states, ) + encoder_outputs[1:]
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+ def feature_extract(
+ self,
+ input_values,
+ seq_len,
+ ):
+ """
+ Extracts features from the input values and returns the extracted features.
+
+ Parameters:
+ input_values (torch.Tensor): The input values to be processed.
+ seq_len (torch.Tensor): The sequence lengths of the input values.
+
+ Returns:
+ extracted_features (torch.Tensor): The extracted features from the input values.
+ """
+ extract_features = self.feature_extractor(input_values)
+ extract_features = extract_features.transpose(1, 2)
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
+
+ return extract_features
+
+ def encode(
+ self,
+ extract_features,
+ attention_mask=None,
+ mask_time_indices=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ """
+ Encodes the input features into the output space.
+
+ Args:
+ extract_features (torch.Tensor): The extracted features from the audio signal.
+ attention_mask (torch.Tensor, optional): Attention mask to be used for padding.
+ mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension.
+ output_attentions (bool, optional): If set to True, returns the attention weights.
+ output_hidden_states (bool, optional): If set to True, returns all hidden states.
+ return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of the tuple.
+
+ Returns:
+ The encoded output features.
+ """
+ self.config.output_attentions = True
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if attention_mask is not None:
+ # compute reduced attention_mask corresponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(
+ extract_features.shape[1], attention_mask, add_adapter=False
+ )
+
+ hidden_states, extract_features = self.feature_projection(extract_features)
+ hidden_states = self._mask_hidden_states(
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+ )
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if self.adapter is not None:
+ hidden_states = self.adapter(hidden_states)
+
+ if not return_dict:
+ return (hidden_states, ) + encoder_outputs[1:]
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
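+
+# Usage sketch for Wav2VecModel (illustrative; the checkpoint path is a placeholder):
+#   model = Wav2VecModel.from_pretrained("path/to/wav2vec2-base-960h").eval()
+#   out = model(torch.randn(1, 32000), seq_len=50)   # 2 s of 16 kHz audio -> 50 steps at 25 fps
+#   out.last_hidden_state.shape                      # torch.Size([1, 50, 768])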
+
+
+def linear_interpolation(features, seq_len):
+ """
+    Linearly interpolate the features along the time axis to the given sequence length.
+
+ Args:
+ features (torch.Tensor): The extracted features to be interpolated.
+ seq_len (torch.Tensor): The sequence lengths of the features.
+
+ Returns:
+ torch.Tensor: The interpolated features.
+ """
+ features = features.transpose(1, 2)
+ output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
+ return output_features.transpose(1, 2)
+
+
+def linear_interpolation_fps(features, input_fps, output_fps, output_len=None):
+ features = features.transpose(1, 2) # [1, C, T]
+ seq_len = features.shape[2] / float(input_fps)
+ if output_len is None:
+ output_len = int(seq_len * output_fps)
+ output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear')
+ return output_features.transpose(1, 2)
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..e06a4e85c6290fb63503decc5f3e8a724e836a81
--- /dev/null
+++ b/main.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Inference codes adapted from [SeedVR]
+# https://github.com/ByteDance-Seed/SeedVR/blob/main/projects/inference_seedvr2_7b.py
+
+from sys import argv
+import sys
+
+path_to_insert = "humo"
+if path_to_insert not in sys.path:
+ sys.path.insert(0, path_to_insert)
+
+from common.config import load_config, create_object
+
+# Load config.
+config = load_config(argv[1], argv[2:])
+
+runner = create_object(config)
+runner.entrypoint()
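+
+# Typical invocation (see the scripts/ directory for complete torchrun examples): the first
+# argument is a config file and the remaining arguments are dot-path overrides, e.g.
+#   torchrun --nproc_per_node=1 main.py humo/configs/inference/generate.yaml generation.mode=TA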
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e4705239063aed5fc9e42fba911182af3b81e09c
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,12 @@
+mediapy
+diffusers
+transformers
+torchao
+imageio
+numpy==1.26.4
+ftfy
+audio-separator==0.24.1
+onnxruntime
+omegaconf
+moviepy==1.0.3
+kernels
diff --git a/scripts/infer_ta.sh b/scripts/infer_ta.sh
new file mode 100644
index 0000000000000000000000000000000000000000..89be8106aaa2e23c7fc8ac1a57b2ff3da50c1541
--- /dev/null
+++ b/scripts/infer_ta.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+torchrun --node_rank=0 --nproc_per_node=8 --nnodes=1 \
+ --rdzv_endpoint=127.0.0.1:12345 \
+ --rdzv_conf=timeout=900,join_timeout=900,read_timeout=900 \
+ main.py humo/configs/inference/generate.yaml \
+ generation.frames=97 \
+ generation.scale_a=5.5 \
+ generation.scale_t=5.0 \
+ generation.mode=TA \
+ generation.height=720 \
+ generation.width=1280 \
+ diffusion.timesteps.sampling.steps=50 \
+ generation.positive_prompt=./examples/test_case.json \
+ generation.output.dir=./output
diff --git a/scripts/infer_ta_1_7B.sh b/scripts/infer_ta_1_7B.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d87331640b06800aa4fbb2a22a87620656010629
--- /dev/null
+++ b/scripts/infer_ta_1_7B.sh
@@ -0,0 +1,49 @@
+#!/bin/bash
+
+# Run on a single GPU
+CUDA_VISIBLE_DEVICES=0 torchrun --node_rank=0 --nproc_per_node=1 --nnodes=1 \
+ --rdzv_endpoint=127.0.0.1:12345 \
+ --rdzv_conf=timeout=900,join_timeout=900,read_timeout=900 \
+ main.py humo/configs/inference/generate_1_7B.yaml \
+ dit.sp_size=1 \
+ generation.frames=97 \
+ generation.scale_t=7.0 \
+ generation.scale_a=7.5 \
+ generation.mode=TA \
+ generation.height=480 \
+ generation.width=832 \
+ diffusion.timesteps.sampling.steps=50 \
+ generation.positive_prompt=./examples/test_case.json \
+ generation.output.dir=./output
+
+# # Run on 2 GPUs
+# CUDA_VISIBLE_DEVICES=0,1 torchrun --node_rank=0 --nproc_per_node=2 --nnodes=1 \
+# --rdzv_endpoint=127.0.0.1:12345 \
+# --rdzv_conf=timeout=900,join_timeout=900,read_timeout=900 \
+# main.py humo/configs/inference/generate_1_7B.yaml \
+# dit.sp_size=2 \
+# generation.frames=97 \
+# generation.scale_t=7.0 \
+# generation.scale_a=7.5 \
+# generation.mode=TA \
+# generation.height=480 \
+# generation.width=832 \
+# diffusion.timesteps.sampling.steps=50 \
+# generation.positive_prompt=./examples/test_case.json \
+# generation.output.dir=./output
+
+# # Run on 4 GPUs
+# CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --node_rank=0 --nproc_per_node=4 --nnodes=1 \
+# --rdzv_endpoint=127.0.0.1:12345 \
+# --rdzv_conf=timeout=900,join_timeout=900,read_timeout=900 \
+# main.py humo/configs/inference/generate_1_7B.yaml \
+# dit.sp_size=4 \
+# generation.frames=97 \
+# generation.scale_t=7.0 \
+# generation.scale_a=7.5 \
+# generation.mode=TA \
+# generation.height=480 \
+# generation.width=832 \
+# diffusion.timesteps.sampling.steps=50 \
+# generation.positive_prompt=./examples/test_case.json \
+# generation.output.dir=./output
\ No newline at end of file
diff --git a/scripts/infer_tia.sh b/scripts/infer_tia.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e628342db1058ebcb628ab61742bad878e747314
--- /dev/null
+++ b/scripts/infer_tia.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+torchrun --node_rank=0 --nproc_per_node=8 --nnodes=1 \
+ --rdzv_endpoint=127.0.0.1:12345 \
+ --rdzv_conf=timeout=900,join_timeout=900,read_timeout=900 \
+ main.py humo/configs/inference/generate.yaml \
+ generation.frames=97 \
+ generation.scale_a=5.5 \
+ generation.scale_t=5.0 \
+ generation.mode=TIA \
+ generation.height=720 \
+ generation.width=1280 \
+ diffusion.timesteps.sampling.steps=50 \
+ generation.positive_prompt=./examples/test_case.json \
+ generation.output.dir=./output
diff --git a/scripts/infer_tia_1_7B.sh b/scripts/infer_tia_1_7B.sh
new file mode 100644
index 0000000000000000000000000000000000000000..45e5e5ae9a12fc062fbb0d585c741eaa87f5b856
--- /dev/null
+++ b/scripts/infer_tia_1_7B.sh
@@ -0,0 +1,52 @@
+#!/bin/bash
+
+# Run on a single GPU
+CUDA_VISIBLE_DEVICES=0 torchrun --node_rank=0 --nproc_per_node=1 --nnodes=1 \
+ --rdzv_endpoint=127.0.0.1:12345 \
+ --rdzv_conf=timeout=900,join_timeout=900,read_timeout=900 \
+ main.py humo/configs/inference/generate_1_7B.yaml \
+ dit.sp_size=1 \
+ generation.frames=97 \
+ generation.scale_t=7.0 \
+ generation.scale_i=4.0 \
+ generation.scale_a=7.5 \
+ generation.mode=TIA \
+ generation.height=480 \
+ generation.width=832 \
+ diffusion.timesteps.sampling.steps=50 \
+ generation.positive_prompt=./examples/test_case.json \
+ generation.output.dir=./output
+
+# # Run on 2 GPUs
+# CUDA_VISIBLE_DEVICES=0,1 torchrun --node_rank=0 --nproc_per_node=2 --nnodes=1 \
+# --rdzv_endpoint=127.0.0.1:12345 \
+# --rdzv_conf=timeout=900,join_timeout=900,read_timeout=900 \
+# main.py humo/configs/inference/generate_1_7B.yaml \
+# dit.sp_size=2 \
+# generation.frames=97 \
+# generation.scale_t=7.0 \
+# generation.scale_i=4.0 \
+# generation.scale_a=7.5 \
+# generation.mode=TIA \
+# generation.height=480 \
+# generation.width=832 \
+# diffusion.timesteps.sampling.steps=50 \
+# generation.positive_prompt=./examples/test_case.json \
+# generation.output.dir=./output
+
+# # Run on 4 GPUs
+# CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --node_rank=0 --nproc_per_node=4 --nnodes=1 \
+# --rdzv_endpoint=127.0.0.1:12345 \
+# --rdzv_conf=timeout=900,join_timeout=900,read_timeout=900 \
+# main.py humo/configs/inference/generate_1_7B.yaml \
+# dit.sp_size=4 \
+# generation.frames=97 \
+# generation.scale_t=7.0 \
+# generation.scale_i=4.0 \
+# generation.scale_a=7.5 \
+# generation.mode=TIA \
+# generation.height=480 \
+# generation.width=832 \
+# diffusion.timesteps.sampling.steps=50 \
+# generation.positive_prompt=./examples/test_case.json \
+# generation.output.dir=./output