diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..c729f041a0c89076070ba075b8e8df90e5ad187b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/teaser.png filter=lfs diff=lfs merge=lfs -text +examples/amber.png filter=lfs diff=lfs merge=lfs -text +examples/armour.png filter=lfs diff=lfs merge=lfs -text +examples/art.wav filter=lfs diff=lfs merge=lfs -text +examples/chris.png filter=lfs diff=lfs merge=lfs -text +examples/dream.mp3 filter=lfs diff=lfs merge=lfs -text +examples/fictional.wav filter=lfs diff=lfs merge=lfs -text +examples/fight.wav filter=lfs diff=lfs merge=lfs -text +examples/jacket.png filter=lfs diff=lfs merge=lfs -text +examples/naomi.png filter=lfs diff=lfs merge=lfs -text +examples/vangogh.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..62ae395d2b41b3006f5f79d5a7539bad424b21a6 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2025 Bytedance + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/README.md b/README.md index c9bc0df86bbd5f3622eb509cf68dca0e9d3289cf..1c090487f44deadc989c3613db1527af22827d6a 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,13 @@ ---- -title: HuMo Local -emoji: 🌍 -colorFrom: green -colorTo: red -sdk: gradio -sdk_version: 5.49.1 -app_file: app.py -pinned: false ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +--- +title: HuMo [Local] +emoji: 👩‍🦱 +colorFrom: purple +colorTo: gray +sdk: gradio +sdk_version: 5.47.2 +app_file: app.py +pinned: false +short_description: Reference based video generation +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..8f8dfff7e5164c3e2b60234481c9e6a2527b47bc --- /dev/null +++ b/app.py @@ -0,0 +1,435 @@ +import spaces +import gradio as gr +import sys +import os +import subprocess +import uuid +import shutil + + + +from huggingface_hub import snapshot_download, list_repo_files, hf_hub_download +import importlib, site + + +# Re-discover all .pth/.egg-link files +for sitedir in site.getsitepackages(): + site.addsitedir(sitedir) + +# Clear caches so importlib will pick up new modules +importlib.invalidate_caches() + +def sh(cmd): subprocess.check_call(cmd, shell=True) + +flash_attention_installed = False + +try: + flash_attention_wheel = hf_hub_download( + repo_id="alexnasa/flash-attn-3", + repo_type="model", + filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl", + ) + + sh(f"pip install {flash_attention_wheel}") + print("Attempting to download and install FlashAttention wheel...") + # sh("pip install flash-attn") + sh("pip install --no-build-isolation transformer_engine-2.5.0+f05f12c9-cp310-cp310-linux_x86_64.whl") + + # tell Python to re-scan site-packages now that the egg-link exists + import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches() + + flash_attention_installed = True + +except Exception as e: + print(f"⚠️ Could not install FlashAttention: {e}") + print("Continuing without FlashAttention...") + +try: + te_wheel = hf_hub_download( + repo_id="alexnasa/transformer_engine_wheels", + repo_type="model", + filename="transformer_engine-2.5.0+f05f12c9-cp310-cp310-linux_x86_64.whl", + ) + + sh(f"pip install {te_wheel}") + print("Attempting to download and install Transformer Engine wheel...") + + # tell Python to re-scan site-packages now that the egg-link exists + import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches() + +except Exception as e: + print(f"⚠️ Could not install Transformer Engine : {e}") + print("Continuing without Transformer Engine ...") + +import torch +print(f"Torch version: {torch.__version__}") +print(f"FlashAttention available: {flash_attention_installed}") + +import tempfile +from pathlib import Path +from torch._inductor.runtime.runtime_utils import cache_dir as _inductor_cache_dir +from huggingface_hub import HfApi + + +snapshot_download(repo_id="bytedance-research/HuMo", local_dir="./weights/HuMo") +snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./weights/Wan2.1-T2V-1.3B") +snapshot_download(repo_id="openai/whisper-large-v3", local_dir="./weights/whisper-large-v3") + +os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/proprocess_results" + +path_to_insert = "humo" +if path_to_insert not in sys.path: + 
sys.path.insert(0, path_to_insert) + +from common.config import load_config, create_object + +config = load_config( + "./humo/configs/inference/generate.yaml", + [ + "dit.sp_size=1", + "generation.frames=97", + "generation.scale_t=5.5", + "generation.scale_a=5.0", + "generation.mode=TIA", + "generation.height=480", + "generation.width=832", + ], +) +runner = create_object(config) + + +os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", f"{os.getcwd()}/torchinductor_space") # or another writable path + +def restore_inductor_cache_from_hub(repo_id: str, filename: str = "torch_compile_cache.zip", + path_in_repo: str = "inductor_cache", repo_type: str = "model", + hf_token: str | None = None): + cache_root = Path(_inductor_cache_dir()).resolve() + cache_root.mkdir(parents=True, exist_ok=True) + zip_path = hf_hub_download(repo_id=repo_id, filename=f"{path_in_repo}/{filename}", + repo_type=repo_type, token=hf_token) + shutil.unpack_archive(zip_path, extract_dir=str(cache_root)) + print(f"✓ Restored cache into {cache_root}") + + +# restore_inductor_cache_from_hub("alexnasa/humo-compiled") + + +def get_duration(prompt_text, steps, image_file, audio_file_path, tea_cache_l1_thresh, max_duration, session_id): + + return calculate_required_time(steps, max_duration) + +def calculate_required_time(steps, max_duration): + + warmup_s = 60 + + max_duration_duration_mapping = { + 1: 8, + 2: 8, + 3: 11, + 4: 20, + 5: 30, + } + each_step_s = max_duration_duration_mapping[max_duration] + duration_s = (each_step_s * steps) + warmup_s + + print(f'estimated duration:{duration_s}') + + return int(duration_s) + +def get_required_time_string(steps, max_duration): + + duration_s = calculate_required_time(steps, max_duration) + duration_m = duration_s / 60 + + return f"
⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)
" + +def update_required_time(steps, max_duration): + + return get_required_time_string(steps, max_duration) + + +def generate_scene(prompt_text, steps, image_paths, audio_file_path, tea_cache_l1_thresh, max_duration = 2, session_id = None): + + print(image_paths) + prompt_text_check = (prompt_text or "").strip() + if not prompt_text_check: + raise gr.Error("Please enter a prompt.") + + if not audio_file_path and not image_paths: + raise gr.Error("Please provide a reference image or a lipsync audio.") + + return run_pipeline(prompt_text, steps, image_paths, audio_file_path, tea_cache_l1_thresh, max_duration, session_id) + + + +def upload_inductor_cache_to_hub( + repo_id: str, + path_in_repo: str = "inductor_cache", + repo_type: str = "model", # or "dataset" if you prefer + hf_token: str | None = None, +): + """ + Zips the current TorchInductor cache and uploads it to the given repo path. + Assumes the model was already run once with torch.compile() so the cache exists. + """ + + cache_dir = Path(_inductor_cache_dir()).resolve() + if not cache_dir.exists(): + raise FileNotFoundError(f"TorchInductor cache not found at {cache_dir}. " + "Run a compiled model once to populate it.") + + # Create a zip archive of the entire cache directory + with tempfile.TemporaryDirectory() as tmpdir: + archive_base = Path(tmpdir) / "torch_compile_cache" + archive_path = shutil.make_archive(str(archive_base), "zip", root_dir=str(cache_dir)) + archive_path = Path(archive_path) + + # Upload to Hub + api = HfApi(token=hf_token) + api.create_repo(repo_id=repo_id, repo_type=repo_type, exist_ok=True) + # Put each artifact under path_in_repo, including a tiny metadata stamp for traceability + # Upload the zip + dest_path = f"{path_in_repo}/{archive_path.name}" + api.upload_file( + path_or_fileobj=str(archive_path), + path_in_repo=dest_path, + repo_id=repo_id, + repo_type=repo_type, + ) + # Upload a small metadata file (optional but handy) + meta_txt = ( + f"pytorch={torch.__version__}\n" + f"inductor_cache_dir={cache_dir}\n" + f"cuda_available={torch.cuda.is_available()}\n" + f"cuda_device={torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu'}\n" + ) + api.upload_file( + path_or_fileobj=meta_txt.encode(), + path_in_repo=f"{path_in_repo}/INDUCTOR_CACHE_METADATA.txt", + repo_id=repo_id, + repo_type=repo_type, + ) + + print("✔ Uploaded TorchInductor cache to the Hub.") + + +@spaces.GPU(duration=get_duration) +def run_pipeline(prompt_text, steps, image_paths, audio_file_path, tea_cache_l1_thresh = 0.0, max_duration = 2, session_id = None): + + if session_id is None: + session_id = uuid.uuid4().hex + + inference_mode = "TIA" + + # Validate inputs + prompt_text = (prompt_text or "").strip() + if not prompt_text: + raise gr.Error("Please enter a prompt.") + + if not audio_file_path and not image_paths: + raise gr.Error("Please provide a reference image or a lipsync audio.") + + if not audio_file_path: + inference_mode = "TI" + audio_path = None + else: + audio_path = audio_file_path if isinstance(audio_file_path, str) else getattr(audio_file_path, "name", str(audio_file_path)) + + if not image_paths: + inference_mode = "TA" + img_paths = None + else: + img_paths = [image_data[0] for image_data in image_paths] + + + # Prepare output + output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id) + os.makedirs(output_dir, exist_ok=True) + + # Random filename + filename = f"gen_{uuid.uuid4().hex[:10]}" + width, height = 832, 480 + + duration_frame_mapping = { + 1:25, + 2:45, + 3:70, + 4:97, + 5:129 + } 
+ + # Run inference + runner.inference_loop( + prompt_text, + img_paths, + audio_path, + output_dir, + filename, + inference_mode, + width, + height, + steps, + frames = int(duration_frame_mapping[max_duration]), + tea_cache_l1_thresh = tea_cache_l1_thresh, + ) + + # Return resulting video path + video_path = os.path.join(output_dir, f"{filename}.mp4") + if os.path.exists(video_path): + + # upload_inductor_cache_to_hub("alexnasa/humo-compiled") + + return video_path + else: + candidates = [os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.endswith(".mp4")] + if candidates: + return max(candidates, key=lambda p: os.path.getmtime(p)) + return None + +css = """ + #col-container { + margin: 0 auto; + width: 100%; + max-width: 720px; + } + """ + +def cleanup(request: gr.Request): + + sid = request.session_hash + if sid: + d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid) + shutil.rmtree(d1, ignore_errors=True) + +def start_session(request: gr.Request): + + return request.session_hash + +with gr.Blocks(css=css) as demo: + + session_state = gr.State() + demo.load(start_session, outputs=[session_state]) + + with gr.Sidebar(width=400): + + + gr.HTML( + """ +
+ HuMo – Human-Centric Video Generation via Collaborative Multi-Modal Conditioning
+ [Github]
+ """ + ) + + gr.Markdown("**REFERENCE IMAGES**") + + img_input = gr.Gallery( + show_label=False, + label="", + interactive=True, + rows=1, columns=3, object_fit="contain", height="280", + file_types=['image'] + ) + + gr.Markdown("**LIPSYNC AUDIO**") + + audio_input = gr.Audio( + sources=["upload"], + show_label=False, + type="filepath", + ) + + gr.Markdown("**SETTINGS**") + + default_steps = 10 + default_max_duration = 2 + + max_duration = gr.Slider(minimum=2, maximum=5, value=default_max_duration, step=1, label="Max Duration") + steps_input = gr.Slider(minimum=5, maximum=50, value=default_steps, step=5, label="Diffusion Steps") + tea_cache_l1_thresh = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.01, label="Cache", visible=False) + + + + with gr.Column(elem_id="col-container"): + + gr.HTML( + """ +
+ HF Space by:
+ GitHub Repo
+ """ + ) + + video_output = gr.Video(show_label=False) + + gr.Markdown("

PROMPT

") + + prompt_tb = gr.Textbox( + show_label=False, + lines=5, + placeholder="Describe the scene and the person talking....", + ) + + gr.Markdown("") + time_required = gr.Markdown(get_required_time_string(default_steps, default_max_duration)) + run_btn = gr.Button("🎬 Action", variant="primary") + + gr.Examples( + examples=[ + + [ + "A handheld tracking shot follows a female warrior walking through a cave. Her determined eyes are locked straight ahead. She speaks with intensity.", + 5, + ["./examples/naomi.png"], + "./examples/dream.mp3", + ], + + [ + "A reddish-brown haired and bearded man sits pensively against swirling blue-and-white brushstrokes, dressed in a blue coat and dark waistcoat. The artistic backdrop and his thoughtful pose evoke a Post-Impressionist style in a studio-like setting.", + 10, + ["./examples/vangogh.jpg"], + "./examples/art.wav", + ], + + [ + "A handheld tracking shot follows a female through a science lab. Her determined eyes are locked straight ahead. The clip is in black and white and patchy as she is explaining something to someone standing opposite her", + 10, + ["./examples/naomi.png"], + "./examples/science.wav", + ], + + [ + "A woman with long, wavy dark hair looking at a person sitting opposite her whilst holding a book, wearing a leather jacket, long-sleeved jacket with a semi purple color one seen on a photo. Warm, window-like light bathes her figure, highlighting the outfit's elegant design and her graceful movements.", + 50, + ["./examples/amber.png", "./examples/jacket.png"], + "./examples/fictional.mp3", + ], + + ], + inputs=[prompt_tb, steps_input, img_input, audio_input], + outputs=[video_output], + fn=run_pipeline, + cache_examples=True, + ) + max_duration.change(update_required_time, [steps_input, max_duration], time_required) + steps_input.change(update_required_time, [steps_input, max_duration], time_required) + + run_btn.click( + fn=generate_scene, + inputs=[prompt_tb, steps_input, img_input, audio_input, tea_cache_l1_thresh, max_duration, session_state], + outputs=[video_output], + ) + + +if __name__ == "__main__": + demo.unload(cleanup) + demo.queue() + demo.launch(ssr_mode=False) diff --git a/assets/teaser.png b/assets/teaser.png new file mode 100644 index 0000000000000000000000000000000000000000..906c01b24258d9cbea17374f129eec02f68ada6c --- /dev/null +++ b/assets/teaser.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:722d29d27fb89a6e1ebebef233492f0c06c25b09a0bdb8e723ef567e778bcf34 +size 5832395 diff --git a/examples/amber.png b/examples/amber.png new file mode 100644 index 0000000000000000000000000000000000000000..be254a66a54f637325bd46a5b43c01f32770c58e --- /dev/null +++ b/examples/amber.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6ce1a891ea71b184eeb4bc768322006c4ecccc8e063b2a0afd38829c6e975f03 +size 2396702 diff --git a/examples/armour.png b/examples/armour.png new file mode 100644 index 0000000000000000000000000000000000000000..0fa22b591e19170c6fcb976340798541ff31c983 --- /dev/null +++ b/examples/armour.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:192dd4b1c80c9ddacb8678962b5a1c04855d44c9877810aba032761ce50052a2 +size 1790470 diff --git a/examples/art.wav b/examples/art.wav new file mode 100644 index 0000000000000000000000000000000000000000..f5ecf0684d1849be5c43d623e9249ee168a2f5f0 --- /dev/null +++ b/examples/art.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:72c75df8e93a107e262ea9b002a66e72d3c1cd2084bce1474a31d8afffd0b651 +size 114254 diff --git a/examples/chris.png b/examples/chris.png new file mode 100644 index 0000000000000000000000000000000000000000..47e1466a43fa607d689a0ca94c5f3c1956511b14 --- /dev/null +++ b/examples/chris.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a3100088e3247d8ecaf1ded2e2417e70d6ad34d24ce7f6e7551cb7fd24c91dcf +size 2053453 diff --git a/examples/dream.mp3 b/examples/dream.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..6dfc3d08efa1132d14863bc64781efc818e5c532 --- /dev/null +++ b/examples/dream.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:27248fd9e8f29bd60ccb1163b8df3c6f2630734f358aa3362ffe67e8148e0eb1 +size 108275 diff --git a/examples/fictional.wav b/examples/fictional.wav new file mode 100644 index 0000000000000000000000000000000000000000..27bb74c78eb069d23b132b6fa9de7461ad2db58b --- /dev/null +++ b/examples/fictional.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:31b550e6433ea44a0642dee90c326664ff4f568fec184170001f834597b3ad23 +size 167084 diff --git a/examples/fight.wav b/examples/fight.wav new file mode 100644 index 0000000000000000000000000000000000000000..5763aaaf7af9d8056a3ade8b25ca5e5e52c031be --- /dev/null +++ b/examples/fight.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8dbee86c85e992ac6d17820a3730bf753fc9bf5bac6b8a470f84b7e98a64221a +size 264782 diff --git a/examples/jacket.png b/examples/jacket.png new file mode 100644 index 0000000000000000000000000000000000000000..6530268bff18c6d841bb492f906ef500a716cbf4 --- /dev/null +++ b/examples/jacket.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e80a02659148e3364eaa46e46dd64a268f04c7f7eeed0e8f203b6b848738666a +size 2565494 diff --git a/examples/naomi.png b/examples/naomi.png new file mode 100644 index 0000000000000000000000000000000000000000..c9d117f403585c16157f3f06852f538458ff0a10 --- /dev/null +++ b/examples/naomi.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5666cd6253658e76695e8529b28d730c74f9e63a1afca07e47505e59b24e7656 +size 1177240 diff --git a/examples/science.wav b/examples/science.wav new file mode 100644 index 0000000000000000000000000000000000000000..6486ce5d79678c56b3d6dea9174e02b64a0dac8e Binary files /dev/null and b/examples/science.wav differ diff --git a/examples/vangogh.jpg b/examples/vangogh.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e5441c8aa1deac4b12023da5001f6034682c9e8c --- /dev/null +++ b/examples/vangogh.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1ae77da89271f32196ad9e8a915e20a7f71e9a84b78ecec3aa1dfcb4e4b39de1 +size 138650 diff --git a/humo/common/__init__.py b/humo/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/humo/common/config.py b/humo/common/config.py new file mode 100644 index 0000000000000000000000000000000000000000..ea65b5da331eb2acfc7c26d7146ea01cabbb0c9d --- /dev/null +++ b/humo/common/config.py @@ -0,0 +1,107 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Codes adapted from [SeedVR] +# https://github.com/ByteDance-Seed/SeedVR/blob/main/common/config.py + +""" +Configuration utility functions +""" + +import importlib +from typing import Any, Callable, List, Union +from omegaconf import DictConfig, ListConfig, OmegaConf + +OmegaConf.register_new_resolver("eval", eval) + + +def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]: + """ + Load a configuration. Will resolve inheritance. + """ + config = OmegaConf.load(path) + if argv is not None: + config_argv = OmegaConf.from_dotlist(argv) + config = OmegaConf.merge(config, config_argv) + config = resolve_recursive(config, resolve_inheritance) + return config + + +def resolve_recursive( + config: Any, + resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]], +) -> Any: + config = resolver(config) + if isinstance(config, DictConfig): + for k in config.keys(): + v = config.get(k) + if isinstance(v, (DictConfig, ListConfig)): + config[k] = resolve_recursive(v, resolver) + if isinstance(config, ListConfig): + for i in range(len(config)): + v = config.get(i) + if isinstance(v, (DictConfig, ListConfig)): + config[i] = resolve_recursive(v, resolver) + return config + + +def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any: + """ + Recursively resolve inheritance if the config contains: + __inherit__: path/to/parent.yaml. + """ + if isinstance(config, DictConfig): + inherit = config.pop("__inherit__", None) + if inherit: + assert isinstance(inherit, str) + inherit = load_config(inherit) + if len(config.keys()) > 0: + config = OmegaConf.merge(inherit, config) + else: + config = inherit + return config + + +def import_item(path: str, name: str) -> Any: + """ + Import a python item. Example: import_item("path.to.file", "MyClass") -> MyClass + """ + return getattr(importlib.import_module(path), name) + + +def create_object(config: DictConfig) -> Any: + """ + Create an object from config. + The config is expected to contains the following: + __object__: + path: path.to.module + name: MyClass + args: as_config | as_params (default to as_config) + """ + item = import_item( + path=config.__object__.path, + name=config.__object__.name, + ) + args = config.__object__.get("args", "as_config") + if args == "as_config": + return item(config) + if args == "as_params": + config = OmegaConf.to_object(config) + config.pop("__object__") + return item(**config) + raise NotImplementedError(f"Unknown args type: {args}") + + +def create_dataset(path: str, *args, **kwargs) -> Any: + """ + Create a dataset. Requires the file to contain a "create_dataset" function. + """ + return import_item(path, "create_dataset")(*args, **kwargs) diff --git a/humo/common/distributed/__init__.py b/humo/common/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..21e58bc73d36cf0d6b9bdb818dd3f504c04ebbe4 --- /dev/null +++ b/humo/common/distributed/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Codes adapted from [SeedVR] +# https://github.com/ByteDance-Seed/SeedVR/tree/main/common/distributed + +""" +Distributed package. +""" + +from .basic import ( + barrier_if_distributed, + convert_to_ddp, + get_device, + get_global_rank, + get_local_rank, + get_world_size, + init_torch, + meta_param_init_fn, + meta_non_persistent_buffer_init_fn +) + +__all__ = [ + "barrier_if_distributed", + "convert_to_ddp", + "get_device", + "get_global_rank", + "get_local_rank", + "get_world_size", + "init_torch", + "meta_param_init_fn", + "meta_non_persistent_buffer_init_fn", +] diff --git a/humo/common/distributed/advanced.py b/humo/common/distributed/advanced.py new file mode 100644 index 0000000000000000000000000000000000000000..e349f5c6674028e932145af4b3279766969186e8 --- /dev/null +++ b/humo/common/distributed/advanced.py @@ -0,0 +1,484 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Codes adapted from [SeedVR] +# https://github.com/ByteDance-Seed/SeedVR/tree/main/common/distributed + +""" +Advanced distributed functions for sequence parallel. +""" + +import torch +from typing import Any, List, Optional, Tuple, Union +import torch.distributed as dist +from torch import Tensor + +from .basic import get_global_rank, get_world_size + + +_DATA_PARALLEL_GROUP = None +_SEQUENCE_PARALLEL_GROUP = None +_SEQUENCE_PARALLEL_CPU_GROUP = None + + +_CFG_PARALLEL_GROUP = None +_CFG_PARALLEL_CPU_GROUP = None + +def get_data_parallel_group() -> Optional[dist.ProcessGroup]: + """ + Get data parallel process group. + """ + return _DATA_PARALLEL_GROUP + + +def get_sequence_parallel_group() -> Optional[dist.ProcessGroup]: + """ + Get sequence parallel process group. + """ + return _SEQUENCE_PARALLEL_GROUP + + +def get_sequence_parallel_cpu_group() -> Optional[dist.ProcessGroup]: + """ + Get sequence parallel CPU process group. + """ + return _SEQUENCE_PARALLEL_CPU_GROUP + + +def get_data_parallel_rank() -> int: + """ + Get data parallel rank. + """ + group = get_data_parallel_group() + return dist.get_rank(group) if group else get_global_rank() + + +def get_data_parallel_world_size() -> int: + """ + Get data parallel world size. + """ + group = get_data_parallel_group() + return dist.get_world_size(group) if group else get_world_size() + + +def get_sequence_parallel_rank() -> int: + """ + Get sequence parallel rank. 
+ """ + group = get_sequence_parallel_group() + return dist.get_rank(group) if group else 0 + + +def get_sequence_parallel_world_size() -> int: + """ + Get sequence parallel world size. + """ + group = get_sequence_parallel_group() + return dist.get_world_size(group) if group else 1 + + +def init_unified_parallel(unified_parallel_size): + global _SEQUENCE_PARALLEL_GROUP + global _SEQUENCE_PARALLEL_CPU_GROUP + + if unified_parallel_size == 1: + return + + assert dist.is_initialized() + world_size = dist.get_world_size() + rank = dist.get_rank() + assert world_size % unified_parallel_size == 0 + data_parallel_size = world_size // unified_parallel_size + + for i in range(data_parallel_size): + # build unified parallel group + start_rank = i * unified_parallel_size + end_rank = start_rank + unified_parallel_size + unified_parallel_ranks = range(start_rank, end_rank) + unified_parallel_group = dist.new_group(unified_parallel_ranks) + unified_parallel_cpu_group = dist.new_group(unified_parallel_ranks, backend="gloo") + if rank in unified_parallel_ranks: + _SEQUENCE_PARALLEL_GROUP = unified_parallel_group + _SEQUENCE_PARALLEL_CPU_GROUP = unified_parallel_cpu_group + + +def get_unified_parallel_group(): + global _SEQUENCE_PARALLEL_GROUP + return _SEQUENCE_PARALLEL_GROUP + + +def get_unified_parallel_cpu_group(): + global _SEQUENCE_PARALLEL_CPU_GROUP + return _SEQUENCE_PARALLEL_CPU_GROUP + + +def get_unified_parallel_rank(): + group = get_unified_parallel_group() + return dist.get_rank(group) if group else 0 + + +def get_unified_parallel_world_size(): + group = get_unified_parallel_group() + return dist.get_world_size(group) if group else 1 + + +def is_unified_parallel_initialized(): + group = get_unified_parallel_group() + return group is not None + + +def pad_tensor(x: Tensor, dim: int, padding_size: int): + shape = list(x.shape) + shape[dim] = padding_size + pad = torch.zeros(shape, dtype=x.dtype, device=x.device) + return torch.cat([x, pad], dim=dim) + + +class Slice(torch.autograd.Function): + @staticmethod + def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, dim: int, scale_grad: bool) -> Tensor: + ctx.group = group + ctx.rank = dist.get_rank(group) + seq_world_size = dist.get_world_size(group) + ctx.seq_world_size = seq_world_size + ctx.dim = dim + ctx.scale_grad = scale_grad + dim_size = local_input.shape[dim] + if not ctx.group: + return local_input + return local_input.split(dim_size // seq_world_size, dim=dim)[ctx.rank].contiguous() + + @staticmethod + def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None]: + if not ctx.group: + return None, grad_output, None, None + dim_size = list(grad_output.size()) + split_size = dim_size[0] + dim_size[0] = dim_size[0] * ctx.seq_world_size + output = torch.empty(dim_size, dtype=grad_output.dtype, device=torch.cuda.current_device()) + dist.all_gather_into_tensor(output, grad_output, group=ctx.group) + if ctx.scale_grad: + output = output / ctx.seq_world_size + return (None, torch.cat(output.split(split_size), dim=ctx.dim), None, None) + + +def gather_outputs( + x: Tensor, + gather_dim: int, + padding_dim: Optional[int] = None, + unpad_dim_size: Optional[int] = None, + scale_grad=True, +): + """ + A func to gather the outputs for the model result in sequence parallel + """ + group = get_unified_parallel_group() + if not group: + return x + x = Gather.apply(group, x, gather_dim, scale_grad) + if padding_dim is not None: + x = unpadding_tensor_for_seqeunce_parallel(x, padding_dim, unpad_dim_size) + return x + + +def 
unpadding_tensor_for_seqeunce_parallel(x: Tensor, dim: int, unpadded_dim_size: int): + """ + A func to remove the padding part of the tensor based on its original shape + """ + group = get_unified_parallel_group() + if group is None: + return x + sp_world = get_unified_parallel_world_size() + if unpadded_dim_size % sp_world == 0: + return x + padding_size = sp_world - (unpadded_dim_size % sp_world) + assert (padding_size + unpadded_dim_size) % sp_world == 0 + return unpad_tensor(x, dim=dim, padding_size=padding_size) + + +def gather_seq_scatter_heads_qkv( + qkv_tensor: Tensor, + seq_dim: int, + unpadded_dim_size: Optional[int] = None, + restore_shape: bool = True, + async_op: bool = False, +): + """ + A func to sync splited qkv tensor + qkv_tensor: the tensor we want to do alltoall with. The last dim must + be the projection_idx, which we will split into 3 part. After + spliting, the gather idx will be projecttion_idx + 1 + seq_dim: gather_dim for all2all comm + restore_shape: if True, output will has the same shape length as input + """ + group = get_unified_parallel_group() + if not group: + return qkv_tensor + world = get_unified_parallel_world_size() + orig_shape = qkv_tensor.shape + scatter_dim = qkv_tensor.dim() + bef_all2all_shape = list(orig_shape) + qkv_proj_dim = bef_all2all_shape[-1] + bef_all2all_shape = bef_all2all_shape[:-1] + [3, qkv_proj_dim // 3] + qkv_tensor = qkv_tensor.view(bef_all2all_shape) + if async_op: + return SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op) + else: + qkv_tensor = SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op) + + if restore_shape: + out_shape = list(orig_shape) + out_shape[seq_dim] *= world + out_shape[-1] = qkv_proj_dim // world + qkv_tensor = qkv_tensor.view(out_shape) + + # remove padding + if unpadded_dim_size and unpadded_dim_size % world != 0: + padding_size = qkv_tensor.size(seq_dim) - unpadded_dim_size + qkv_tensor = unpad_tensor(qkv_tensor, seq_dim, padding_size) + + return qkv_tensor + + +def gather_seq_scatter_double_head( + qkv_tensor: Tensor, + seq_dim: int, + unpadded_dim_size: Optional[int] = None, + restore_shape: bool = True, + async_op: bool = False, +): + """ + A func to sync splited qkv tensor + qkv_tensor: the tensor we want to do alltoall with. The last dim must + be the projection_idx, which we will split into 3 part. 
After + spliting, the gather idx will be projecttion_idx + 1 + seq_dim: gather_dim for all2all comm + restore_shape: if True, output will has the same shape length as input + """ + qkv1_shape = qkv_tensor.shape + group = get_unified_parallel_group() + if not group: + return qkv_tensor + world = get_unified_parallel_world_size() + orig_shape = qkv_tensor.shape + scatter_dim = qkv_tensor.dim() + bef_all2all_shape = list(orig_shape) + qkv_proj_dim = bef_all2all_shape[-1] + bef_all2all_shape = bef_all2all_shape[:-1] + [2, qkv_proj_dim // 2] + qkv_tensor = qkv_tensor.view(bef_all2all_shape) + qkv2_shape = qkv_tensor.shape + if async_op: + return SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op) + else: + qkv_tensor = SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op) + qkv3_shape = qkv_tensor.shape + + if restore_shape: + out_shape = list(orig_shape) + out_shape[seq_dim] *= world + out_shape[-1] = qkv_proj_dim // world + qkv_tensor = qkv_tensor.view(out_shape) + qkv4_shape = qkv_tensor.shape + + # remove padding + if unpadded_dim_size and unpadded_dim_size % world != 0: + padding_size = qkv_tensor.size(seq_dim) - unpadded_dim_size + qkv_tensor = unpad_tensor(qkv_tensor, seq_dim, padding_size) + qkv5_shape = qkv_tensor.shape + + return qkv_tensor + + +class SeqAllToAll(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + local_input: Tensor, + scatter_dim: int, + gather_dim: int, + async_op: bool, + ) -> Tensor: + ctx.group = group + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + ctx.async_op = async_op + return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op) + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: + if ctx.async_op: + input_t = torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous() + else: + input_t = grad_output[0] + return ( + None, + all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False), + None, + None, + None, + None, + ) + + +def all_to_all_tensor( + x: Tensor, + scatter_dim: int, + gather_dim: int, + group: dist.ProcessGroup, + async_op: bool = False, +): + if scatter_dim <= 1 and gather_dim <= 1: + return _all_to_all_single(x, scatter_dim, gather_dim, group, async_op) + else: + return _all_to_all(x, scatter_dim, gather_dim, group, async_op) # 走这里 + + +def _all_to_all( + local_input: Tensor, + scatter_dim: int, + gather_dim: int, + group: dist.ProcessGroup, + async_op: bool = False, +): + seq_world_size = dist.get_world_size(group) + input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)] + comm = dist.all_to_all(output_list, input_list, group=group, async_op=async_op) + if async_op: + + def wait(): + comm.wait() + return torch.cat(output_list, dim=gather_dim).contiguous() + + return wait + return torch.cat(output_list, dim=gather_dim).contiguous() + + +def _all_to_all_single(x: Tensor, scatter_dim: int, gather_dim: int, group: dist.ProcessGroup, async_op: bool = False): + """ + A function to do all-to-all on the first two dim + """ + sp_world_size = dist.get_world_size(group) + assert scatter_dim <= 1, "scatter_dim must be 0 or 1 when using all_to_all_single!" + assert gather_dim <= 1, "gather_dim must be 0 or 1 when using all_to_all_single!" 
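+ # dist.all_to_all_single always splits the input and stacks the received chunks along
+ # dim 0. When scatter_dim == 1, the reshape/transpose below therefore moves each rank's
+ # slice of dim 1 into a leading block, so that after the exchange the received blocks
+ # are already concatenated along dim 0, i.e. along the gather dimension.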
+ if scatter_dim != 0: + gather_dim_bef = x.shape[gather_dim] + scatter_dim_bef = x.shape[scatter_dim] + x = ( + x.reshape([gather_dim_bef, sp_world_size, scatter_dim_bef // sp_world_size] + list(x.shape[2:])) + .transpose(0, 1) + .reshape([gather_dim_bef * sp_world_size, scatter_dim_bef // sp_world_size] + list(x.shape[2:])) + .contiguous() + ) + + output = torch.empty_like(x) + comm = dist.all_to_all_single(output, x.contiguous(), group=group, async_op=async_op) + + if async_op: + + def wait(): + comm.wait() + if scatter_dim == 0: + return torch.cat(output.split(x.size(0) // sp_world_size), dim=gather_dim) + else: + return output + + return wait + + if scatter_dim == 0: + output = torch.cat(output.split(x.size(0) // sp_world_size), dim=gather_dim) + return output + + +def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int) -> Tensor: + """ + A func to sync attention result with alltoall in sequence parallel + """ + group = get_unified_parallel_group() + if not group: + return x + dim_size = x.size(seq_dim) + sp_world = get_unified_parallel_world_size() + if dim_size % sp_world != 0: + padding_size = sp_world - (dim_size % sp_world) + x = pad_tensor(x, seq_dim, padding_size) + return SeqAllToAll.apply(group, x, seq_dim, head_dim, False) + + +def unpad_tensor(x: Tensor, dim: int, padding_size: int): + slc = [slice(None)] * len(x.shape) + slc[dim] = slice(0, -padding_size) + return x[slc] + + +class Gather(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + local_input: Tensor, + dim: int, + grad_scale: Optional[bool] = False, + ) -> Tensor: + ctx.group = group + ctx.rank = dist.get_rank(group) + ctx.dim = dim + ctx.grad_scale = grad_scale + seq_world_size = dist.get_world_size(group) + ctx.seq_world_size = seq_world_size + dim_size = list(local_input.size()) + split_size = dim_size[0] + ctx.part_size = dim_size[dim] + dim_size[0] = dim_size[0] * seq_world_size + output = torch.empty(dim_size, dtype=local_input.dtype, device=torch.cuda.current_device()) + dist.all_gather_into_tensor(output, local_input.contiguous(), group=ctx.group) + return torch.cat(output.split(split_size), dim=dim) + + @staticmethod + def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor]: + if ctx.grad_scale: + grad_output = grad_output * ctx.seq_world_size + return ( + None, + grad_output.split(ctx.part_size, dim=ctx.dim)[ctx.rank].contiguous(), + None, + None, + ) + + +def slice_tensor(tensor, dim, start, end): + indices = slice(start, end) + return tensor[(slice(None),) * dim + (indices,)] + + +def init_model_shard_cpu_group(sharding_strategy: str, device_mesh: Optional[Tuple] = None): + """ + Initialize CPU process group of model sharding. 
+ """ + global _MODEL_SHARD_CPU_GROUP + assert dist.is_initialized() + world_size = dist.get_world_size() + rank = dist.get_rank() + if device_mesh is not None: + num_shards_per_group = device_mesh[1] + elif "HYBRID" in sharding_strategy: + num_shards_per_group = min(8, world_size) + else: + num_shards_per_group = world_size + num_groups = world_size // num_shards_per_group + for i in range(num_groups): + start_rank = i * num_shards_per_group + end_rank = (i + 1) * num_shards_per_group + ranks = range(start_rank, end_rank) + cpu_group = dist.new_group(ranks, backend="gloo") + if rank in ranks: + _MODEL_SHARD_CPU_GROUP = cpu_group diff --git a/humo/common/distributed/basic.py b/humo/common/distributed/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..364dc73491a3342ab9898c51432180de8d3c3267 --- /dev/null +++ b/humo/common/distributed/basic.py @@ -0,0 +1,143 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Codes adapted from [SeedVR] +# https://github.com/ByteDance-Seed/SeedVR/tree/main/common/distributed + +""" +Distributed basic functions. +""" + +import os +import torch +from torch import nn +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel +from torch.distributed.fsdp._common_utils import _is_fsdp_flattened + + +def get_global_rank() -> int: + """ + Get the global rank, the global index of the GPU. + """ + return int(os.environ.get("RANK", "0")) + + +def get_local_rank() -> int: + """ + Get the local rank, the local index of the GPU. + """ + return int(os.environ.get("LOCAL_RANK", "0")) + + +def get_world_size() -> int: + """ + Get the world size, the total amount of GPUs. + """ + return int(os.environ.get("WORLD_SIZE", "1")) + + +def get_device() -> torch.device: + """ + Get current rank device. + """ + return torch.device("cuda", get_local_rank()) + + +def barrier_if_distributed(*args, **kwargs): + """ + Synchronizes all processes if under distributed context. + """ + if dist.is_initialized(): + return dist.barrier(*args, **kwargs) + + +def init_torch(cudnn_benchmark=True): + """ + Common PyTorch initialization configuration. + """ + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.backends.cudnn.benchmark = cudnn_benchmark + torch.cuda.set_device(get_local_rank()) + dist.init_process_group( + backend="nccl", + rank=get_global_rank(), + world_size=get_world_size(), + ) + + +def convert_to_ddp(module: torch.nn.Module, **kwargs) -> DistributedDataParallel: + return DistributedDataParallel( + module=module, + device_ids=[get_local_rank()], + output_device=get_local_rank(), + **kwargs, + ) + + +def meta_param_init_fn(module: nn.Module) -> None: + """ + Used for model inited onto meta device. + Init meta param/buffer with empty tensor. + We don't care numerical correctness in this func. + FSDP will sync param/buffer state from rank0 to the other ranks. 
+ """ + + with torch.no_grad(): + for submodule in module.modules(): + for param_name, param in submodule.named_parameters(recurse=False): + if not _is_fsdp_flattened(param) and param.is_meta: + materialized_param = nn.Parameter(torch.empty_like(param, device="cpu")) + setattr(submodule, param_name, materialized_param) + for buffer_name, buffer in submodule.named_buffers(recurse=False): + if not _is_fsdp_flattened(buffer) and buffer.is_meta: + materialized_param = torch.empty_like(buffer, device="cpu") + setattr(submodule, buffer_name, materialized_param) + torch.cuda.empty_cache() + + +def meta_non_persistent_buffer_init_fn(module: nn.Module) -> nn.Module: + """ + Materialize meta device buffers that are not persistent in state_dict. + Handles special cases like RotaryEmbedding.freqs. + """ + with torch.no_grad(): + for submodule in module.modules(): + if hasattr(submodule, "freqs"): + freqs = getattr(submodule, "freqs") + if isinstance(freqs, torch.Tensor) and freqs.is_meta: + dim = submodule.dim + def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer( + torch.arange(max_seq_len), + 1.0 / torch.pow(theta, + torch.arange(0, dim, 2).to(torch.float32).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + dim = 5120 # 1536 + num_heads = 40 # 12 + # dim = 1536 + # num_heads = 12 + d = dim // num_heads + freqs_tensor = torch.cat([ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)) + ], dim=1).to(dtype=torch.cfloat, device="cpu") + + setattr(submodule, "freqs", freqs_tensor) + print(f"Successfully materialized freqs for {submodule.__class__.__name__}") + + assert not any(b.is_meta for n, b in module.named_buffers()) + return module diff --git a/humo/common/logger.py b/humo/common/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..f14adffe7610d97051e97957a34f021e1bd2f6bc --- /dev/null +++ b/humo/common/logger.py @@ -0,0 +1,44 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Codes adapted from [SeedVR] +# https://github.com/ByteDance-Seed/SeedVR/blob/main/common/logger.py + +""" +Logging utility functions. +""" + +import logging +import sys +from typing import Optional + +from common.distributed import get_global_rank, get_local_rank, get_world_size + +_default_handler = logging.StreamHandler(sys.stdout) +_default_handler.setFormatter( + logging.Formatter( + "%(asctime)s " + + (f"[Rank:{get_global_rank()}]" if get_world_size() > 1 else "") + + (f"[LocalRank:{get_local_rank()}]" if get_world_size() > 1 else "") + + "[%(threadName).12s][%(name)s][%(levelname).5s] " + + "%(message)s" + ) +) + + +def get_logger(name: Optional[str] = None) -> logging.Logger: + """ + Get a logger. 
+ """ + logger = logging.getLogger(name) + logger.addHandler(_default_handler) + logger.setLevel(logging.INFO) + return logger diff --git a/humo/configs/inference/generate.yaml b/humo/configs/inference/generate.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aad930e312f613930e80e60047dd5669369c91c2 --- /dev/null +++ b/humo/configs/inference/generate.yaml @@ -0,0 +1,78 @@ +__object__: + path: humo.generate + name: Generator + +dit: + model: + __inherit__: humo/configs/models/Wan_14B_I2V.yaml + __object__: + path: humo.models.wan_modules.model_humo + name: WanModel + insert_audio: True + zero_vae_path: ./weights/HuMo/zero_vae_129frame.pt + zero_vae_720p_path: ./weights/HuMo/zero_vae_720p_161frame.pt + checkpoint_dir: ./weights/HuMo/HuMo-17B + compile: False + init_with_meta_device: True + gradient_checkpoint: True + fsdp: + sharding_strategy: _HYBRID_SHARD_ZERO2 + sp_size: 1 + dtype: bfloat16 + +vae: + checkpoint: ./weights/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth + vae_stride: [ 4, 8, 8 ] + scaling_factor: 0.9152 + compile: False + grouping: True + use_sample: False + dtype: bfloat16 + +text: + t5_checkpoint: ./weights/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth + t5_tokenizer: ./weights/Wan2.1-T2V-1.3B/google/umt5-xxl + dropout: 0.1 + dtype: bfloat16 + fsdp: + enabled: True + sharding_strategy: HYBRID_SHARD + +diffusion: + schedule: + type: lerp + T: 1000.0 + sampler: + type: euler + prediction_type: v_lerp + timesteps: + training: + type: logitnormal + loc: 0.0 + scale: 1.0 + sampling: + type: uniform_trailing + steps: 50 + shift: 5.0 + +audio: + vocal_separator: ./weights/HuMo/audio_separator/Kim_Vocal_2.onnx + wav2vec_model: ./weights/whisper-large-v3 + +generation: + mode: "TIA" # TA, TIA + extract_audio_feat: True + seed: 666666 + frames: 97 + fps: 25 + height: 480 # 720 480 + width: 832 # 1280 832 + batch_size: 1 + sequence_parallel: 8 + output: + dir: ./output + # positive_prompt: ./examples/test_case.json + sample_neg_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走' + scale_a: 5.5 + scale_t: 5.0 + step_change: 980 \ No newline at end of file diff --git a/humo/configs/inference/generate_1_7B.yaml b/humo/configs/inference/generate_1_7B.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2a9d1a88ba457e67d01069fa34d45b8e0896a896 --- /dev/null +++ b/humo/configs/inference/generate_1_7B.yaml @@ -0,0 +1,76 @@ +__object__: + path: humo.generate_1_7B + name: Generator + +dit: + model: + __inherit__: humo/configs/models/Wan_1.3B.yaml + __object__: + path: humo.models.wan_modules.model_humo + name: WanModel + insert_audio: True + zero_vae_path: ./weights/HuMo/zero_vae_129frame.pt + zero_vae_720p_path: ./weights/HuMo/zero_vae_720p_161frame.pt + checkpoint_dir: ./weights/HuMo/HuMo-1.7B/ema.pth #./weights/HuMo/HuMo-17B + compile: False + init_with_meta_device: True + gradient_checkpoint: True + fsdp: + sharding_strategy: _HYBRID_SHARD_ZERO2 + sp_size: 1 + +vae: + checkpoint: ./weights/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth + vae_stride: [ 4, 8, 8 ] + scaling_factor: 0.9152 + compile: False + grouping: True + use_sample: False + dtype: bfloat16 + +text: + t5_checkpoint: ./weights/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth + t5_tokenizer: ./weights/Wan2.1-T2V-1.3B/google/umt5-xxl + dropout: 0.1 + dtype: bfloat16 + fsdp: + enabled: True + sharding_strategy: HYBRID_SHARD + +diffusion: + schedule: + type: lerp + T: 1000.0 + sampler: + type: euler + prediction_type: 
v_lerp + timesteps: + training: + type: logitnormal + loc: 0.0 + scale: 1.0 + sampling: + type: uniform_trailing + steps: 50 + shift: 5.0 + +audio: + vocal_separator: ./weights/audio_separator/Kim_Vocal_2.onnx + wav2vec_model: ./weights/whisper-large-v3 + +generation: + mode: "TIA" # TA, TIA + extract_audio_feat: True + seed: 666666 + frames: 97 + fps: 25 + height: 720 # 480 + width: 1280 # 832 + batch_size: 1 + output: + dir: ./output + sample_neg_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走' + scale_t: 7.5 + scale_i: 4.0 + scale_a: 7.5 + # step_change: 980 \ No newline at end of file diff --git a/humo/configs/models/Wan_1.3B.yaml b/humo/configs/models/Wan_1.3B.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8202790fff334d1d6dae0343a95d2f8a0bbf571c --- /dev/null +++ b/humo/configs/models/Wan_1.3B.yaml @@ -0,0 +1,17 @@ +__object__: + path: ??? + name: ??? + args: as_params + +text_len: 512 +patch_size: [ 1, 2, 2 ] +dim: 1536 +ffn_dim: 8960 +freq_dim: 256 +model_type: "t2v" +num_heads: 12 +num_layers: 30 +window_size: [ -1, -1 ] +qk_norm: True +cross_attn_norm: True +eps: 1e-6 \ No newline at end of file diff --git a/humo/configs/models/Wan_1.3B_I2V.yaml b/humo/configs/models/Wan_1.3B_I2V.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9c3404757dc718f71b69427349479ff9bc33d761 --- /dev/null +++ b/humo/configs/models/Wan_1.3B_I2V.yaml @@ -0,0 +1,18 @@ +__object__: + path: ??? + name: ??? + args: as_params + +text_len: 512 +patch_size: [ 1, 2, 2 ] +dim: 1536 +ffn_dim: 8960 +freq_dim: 256 +in_dim: 36 +model_type: "i2v" +num_heads: 12 +num_layers: 30 +window_size: [ -1, -1 ] +qk_norm: True +cross_attn_norm: True +eps: 1e-6 diff --git a/humo/configs/models/Wan_14B.yaml b/humo/configs/models/Wan_14B.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5da9a868af9ae629ecd1a32638a94cc635b1cad9 --- /dev/null +++ b/humo/configs/models/Wan_14B.yaml @@ -0,0 +1,17 @@ +__object__: + path: ??? + name: ??? + args: as_params + +text_len: 512 +patch_size: [ 1, 2, 2 ] +dim: 5120 +ffn_dim: 13824 +freq_dim: 256 +model_type: "t2v" +num_heads: 40 +num_layers: 40 +window_size: [ -1, -1 ] +qk_norm: True +cross_attn_norm: True +eps: 1e-6 \ No newline at end of file diff --git a/humo/configs/models/Wan_14B_I2V.yaml b/humo/configs/models/Wan_14B_I2V.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c6d7d0082e688d151a58257a5371422b3e13bd31 --- /dev/null +++ b/humo/configs/models/Wan_14B_I2V.yaml @@ -0,0 +1,18 @@ +__object__: + path: ??? + name: ??? + args: as_params + +text_len: 512 +patch_size: [ 1, 2, 2 ] +dim: 5120 +ffn_dim: 13824 +freq_dim: 256 +in_dim: 36 +model_type: "i2v" +num_heads: 40 +num_layers: 40 +window_size: [ -1, -1 ] +qk_norm: True +cross_attn_norm: True +eps: 1e-6 \ No newline at end of file diff --git a/humo/generate.py b/humo/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..c185f15a427ece6f72c75caa7dfd1e32ae278aae --- /dev/null +++ b/humo/generate.py @@ -0,0 +1,984 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Inference codes adapted from [SeedVR] +# https://github.com/ByteDance-Seed/SeedVR/blob/main/projects/inference_seedvr2_7b.py + +import math +import os +import gc +import random +import sys +import mediapy +import numpy as np +import torch +import torch.distributed as dist +from omegaconf import DictConfig, ListConfig, OmegaConf +from einops import rearrange +from omegaconf import OmegaConf +from PIL import Image, ImageOps +from torchvision.transforms import ToTensor +from tqdm import tqdm +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import ( + BackwardPrefetch, + FullyShardedDataParallel, + MixedPrecision, + ShardingStrategy, +) +from common.distributed import ( + get_device, + get_global_rank, + get_local_rank, + meta_param_init_fn, + meta_non_persistent_buffer_init_fn, + init_torch, +) +from common.distributed.advanced import ( + init_unified_parallel, + get_unified_parallel_world_size, + get_sequence_parallel_rank, + init_model_shard_cpu_group, +) +from common.logger import get_logger +from common.config import create_object +from common.distributed import get_device, get_global_rank +from torchvision.transforms import Compose, Normalize, ToTensor +from humo.models.wan_modules.t5 import T5EncoderModel +from humo.models.wan_modules.vae import WanVAE +from humo.models.utils.utils import tensor_to_video, prepare_json_dataset +from contextlib import contextmanager +import torch.cuda.amp as amp +from humo.models.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from humo.utils.audio_processor_whisper import AudioProcessor +from humo.utils.wav2vec import linear_interpolation_fps +from torchao.quantization import quantize_ + +import torch._dynamo as dynamo +dynamo.config.capture_scalar_outputs = True +torch.set_float32_matmul_precision("high") + +import torch +import torch.nn as nn +import transformer_engine.pytorch as te + +image_transform = Compose([ + ToTensor(), + Normalize(mean=0.5, std=0.5), +]) + +SIZE_CONFIGS = { + '720*1280': (720, 1280), + '1280*720': (1280, 720), + '480*832': (480, 832), + '832*480': (832, 480), + '1024*1024': (1024, 1024), +} + +def clever_format(nums, format="%.2f"): + from typing import Iterable + if not isinstance(nums, Iterable): + nums = [nums] + clever_nums = [] + for num in nums: + if num > 1e12: + clever_nums.append(format % (num / 1e12) + "T") + elif num > 1e9: + clever_nums.append(format % (num / 1e9) + "G") + elif num > 1e6: + clever_nums.append(format % (num / 1e6) + "M") + elif num > 1e3: + clever_nums.append(format % (num / 1e3) + "K") + else: + clever_nums.append(format % num + "B") + + clever_nums = clever_nums[0] if len(clever_nums) == 1 else (*clever_nums,) + + return clever_nums + + + +# --- put near your imports --- +import torch +import torch.nn as nn +import contextlib +import transformer_engine.pytorch as te + +# FP8 autocast compatibility for different TE versions +try: + # Preferred modern API + from transformer_engine.pytorch import fp8_autocast + try: + # Newer TE: use recipe-based API + from transformer_engine.common.recipe import DelayedScaling, Format + def 
make_fp8_ctx(enabled: bool = True): + if not enabled: + return contextlib.nullcontext() + fp8_recipe = DelayedScaling(fp8_format=Format.E4M3) # E4M3 format + return fp8_autocast(enabled=True, fp8_recipe=fp8_recipe) + except Exception: + # Very old variant that might still accept fp8_format directly + def make_fp8_ctx(enabled: bool = True): + # If TE doesn't have FP8Format, just no-op + if not hasattr(te, "FP8Format"): + return contextlib.nullcontext() + return te.fp8_autocast(enabled=enabled, fp8_format=te.FP8Format.E4M3) +except Exception: + # TE not present or totally incompatible — no-op + def make_fp8_ctx(enabled: bool = True): + return contextlib.nullcontext() + + +# TE sometimes exposes Linear at different paths; this normalizes it. +try: + TELinear = te.Linear +except AttributeError: # very old layouts + from transformer_engine.pytorch.modules.linear import Linear as TELinear # type: ignore + +# --- near imports --- +import torch +import torch.nn as nn +import transformer_engine.pytorch as te + +try: + TELinear = te.Linear +except AttributeError: + from transformer_engine.pytorch.modules.linear import Linear as TELinear # type: ignore + +import torch +import torch.nn as nn +import transformer_engine.pytorch as te + +try: + TELinear = te.Linear +except AttributeError: + from transformer_engine.pytorch.modules.linear import Linear as TELinear # type: ignore + +def _default_te_allow(fullname: str, lin: nn.Linear) -> bool: + """ + Allow TE only where it's shape-safe & beneficial. + Skip small/special layers (time/timestep/pos embeds, heads). + Enforce multiples of 16 for in/out features (FP8 kernel friendly). + Also skip very small projections likely to see M=1. + """ + blocked_keywords = ( + "time_embedding", "timestep", "time_embed", + "time_projection", "pos_embedding", "pos_embed", + "to_logits", "logits", "final_proj", "proj_out", "output_projection", + ) + if any(k in fullname for k in blocked_keywords): + return False + + # TE FP8 kernels like K, N divisible by 16 + if lin.in_features % 16 != 0 or lin.out_features % 16 != 0: + return False + + # Heuristic: avoid tiny layers; keeps attention/MLP, skips small MLPs + if lin.in_features < 512 or lin.out_features < 512: + return False + + # Whitelist: only convert inside transformer blocks if you know their prefix + # This further reduces risk of catching special heads elsewhere. 
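+    # Illustrative only (hypothetical module names): a 5120x5120 projection under
+    # "blocks.0.self_attn.q" passes the keyword, divisibility-by-16 and size checks
+    # above, while "time_embedding.0" or a 5120x16 output head is rejected.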
+ allowed_context = ("blocks", "layers", "transformer", "attn", "mlp", "ffn") + if not any(tok in fullname for tok in allowed_context): + return False + + return True + +@torch.no_grad() +def convert_linears_to_te_fp8(module: nn.Module, allow_pred=_default_te_allow, _prefix=""): + for name, child in list(module.named_children()): + full = f"{_prefix}.{name}" if _prefix else name + convert_linears_to_te_fp8(child, allow_pred, full) + + if isinstance(child, nn.Linear): + if allow_pred is not None and not allow_pred(full, child): + continue + + te_lin = TELinear( + in_features=child.in_features, + out_features=child.out_features, + bias=(child.bias is not None), + params_dtype=torch.bfloat16, + ).to(child.weight.device) + + te_lin.weight.copy_(child.weight.to(te_lin.weight.dtype)) + if child.bias is not None: + te_lin.bias.copy_(child.bias.to(te_lin.bias.dtype)) + + setattr(module, name, te_lin) + return module + +class Generator(): + def __init__(self, config: DictConfig): + self.config = config.copy() + OmegaConf.set_readonly(self.config, True) + self.logger = get_logger(self.__class__.__name__) + + # init_torch(cudnn_benchmark=False) + self.configure_models() + + def entrypoint(self): + + self.inference_loop() + + def get_fsdp_sharding_config(self, sharding_strategy, device_mesh_config): + device_mesh = None + fsdp_strategy = ShardingStrategy[sharding_strategy] + if ( + fsdp_strategy in [ShardingStrategy._HYBRID_SHARD_ZERO2, ShardingStrategy.HYBRID_SHARD] + and device_mesh_config is not None + ): + device_mesh = init_device_mesh("cuda", tuple(device_mesh_config)) + return device_mesh, fsdp_strategy + + + def configure_models(self): + self.configure_dit_model(device="cuda") + + self.dit.eval().to("cuda") + convert_linears_to_te_fp8(self.dit) + + self.dit = torch.compile(self.dit, ) + + + self.configure_vae_model(device="cuda") + if self.config.generation.get('extract_audio_feat', False): + self.configure_wav2vec(device="cpu") + self.configure_text_model(device="cuda") + + # # Initialize fsdp. + # self.configure_dit_fsdp_model() + # self.configure_text_fsdp_model() + + # quantize_(self.text_encoder, Int8WeightOnlyConfig()) + # quantize_(self.dit, Float8DynamicActivationFloat8WeightConfig()) + + + def configure_dit_model(self, device=get_device()): + + init_unified_parallel(self.config.dit.sp_size) + self.sp_size = get_unified_parallel_world_size() + + # Create DiT model on meta, then mark dtype as bfloat16 (no real allocation yet). + init_device = "meta" + with torch.device(init_device): + self.dit = create_object(self.config.dit.model) + self.dit = self.dit.to(dtype=torch.bfloat16) # or: self.dit.bfloat16() + self.logger.info(f"Load DiT model on {init_device}.") + self.dit.eval().requires_grad_(False) + + # Load dit checkpoint. + path = self.config.dit.checkpoint_dir + + def _cast_state_dict_to_bf16(state): + for k, v in state.items(): + if isinstance(v, torch.Tensor) and v.is_floating_point(): + state[k] = v.to(dtype=torch.bfloat16, copy=False) + return state + + if path.endswith(".pth"): + # Load to CPU first; we’ll move the model later. + state = torch.load(path, map_location="cpu", mmap=True) + state = _cast_state_dict_to_bf16(state) + missing_keys, unexpected_keys = self.dit.load_state_dict(state, strict=False, assign=True) + self.logger.info( + f"dit loaded from {path}. 
Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}" + ) + else: + from safetensors.torch import load_file + import json + def load_custom_sharded_weights(model_dir, base_name): + index_path = f"{model_dir}/{base_name}.safetensors.index.json" + with open(index_path, "r") as f: + index = json.load(f) + weight_map = index["weight_map"] + shard_files = set(weight_map.values()) + state_dict = {} + for shard_file in shard_files: + shard_path = f"{model_dir}/{shard_file}" + # Load on CPU, then cast to bf16; we’ll move the whole module later. + shard_state = load_file(shard_path, device="cpu") + shard_state = {k: (v.to(dtype=torch.bfloat16, copy=False) if v.is_floating_point() else v) + for k, v in shard_state.items()} + state_dict.update(shard_state) + return state_dict + + state = load_custom_sharded_weights(path, 'humo') + self.dit.load_state_dict(state, strict=False, assign=True) + + self.dit = meta_non_persistent_buffer_init_fn(self.dit) + + target_device = get_device() if device in [get_device(), "cuda"] else device + self.dit.to(target_device) # dtype already bf16 + + # Print model size. + params = sum(p.numel() for p in self.dit.parameters()) + self.logger.info( + f"[RANK:{get_global_rank()}] DiT Parameters: {clever_format(params, '%.3f')}" + ) + + + def configure_vae_model(self, device=get_device()): + self.vae_stride = self.config.vae.vae_stride + self.vae = WanVAE( + vae_pth=self.config.vae.checkpoint, + device=device) + + if self.config.generation.height == 480: + self.zero_vae = torch.load(self.config.dit.zero_vae_path) + elif self.config.generation.height == 720: + self.zero_vae = torch.load(self.config.dit.zero_vae_720p_path) + else: + raise ValueError(f"Unsupported height {self.config.generation.height} for zero-vae.") + + def configure_wav2vec(self, device=get_device()): + audio_separator_model_file = self.config.audio.vocal_separator + wav2vec_model_path = self.config.audio.wav2vec_model + + self.audio_processor = AudioProcessor( + 16000, + 25, + wav2vec_model_path, + "all", + audio_separator_model_file, + None, # not seperate + os.path.join(self.config.generation.output.dir, "vocals"), + device=device, + ) + + def configure_text_model(self, device=get_device()): + self.text_encoder = T5EncoderModel( + text_len=self.config.dit.model.text_len, + dtype=torch.bfloat16, + device=device, + checkpoint_path=self.config.text.t5_checkpoint, + tokenizer_path=self.config.text.t5_tokenizer, + ) + + + def configure_dit_fsdp_model(self): + from humo.models.wan_modules.model_humo import WanAttentionBlock + + dit_blocks = (WanAttentionBlock,) + + # Init model_shard_cpu_group for saving checkpoint with sharded state_dict. + init_model_shard_cpu_group( + self.config.dit.fsdp.sharding_strategy, + self.config.dit.fsdp.get("device_mesh", None), + ) + + # Assert that dit has wrappable blocks. + assert any(isinstance(m, dit_blocks) for m in self.dit.modules()) + + # Define wrap policy on all dit blocks. + def custom_auto_wrap_policy(module, recurse, *args, **kwargs): + return recurse or isinstance(module, dit_blocks) + + # Configure FSDP settings. 
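+        # _HYBRID_SHARD_ZERO2 applies ZeRO-2 (SHARD_GRAD_OP) sharding inside each
+        # node and replicates across nodes; HYBRID_SHARD uses full ZeRO-3 sharding
+        # inside each node. The 2-D device_mesh is only built when the config sets one.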
+ device_mesh, fsdp_strategy = self.get_fsdp_sharding_config( + self.config.dit.fsdp.sharding_strategy, + self.config.dit.fsdp.get("device_mesh", None), + ) + settings = dict( + auto_wrap_policy=custom_auto_wrap_policy, + sharding_strategy=fsdp_strategy, + backward_prefetch=BackwardPrefetch.BACKWARD_PRE, + device_id=get_local_rank(), + use_orig_params=False, + sync_module_states=True, + forward_prefetch=True, + limit_all_gathers=False, # False for ZERO2. + mixed_precision=MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, + ), + device_mesh=device_mesh, + param_init_fn=meta_param_init_fn, + ) + + # Apply FSDP. + self.dit = FullyShardedDataParallel(self.dit, **settings) + # self.dit.to(get_device()) + + + def configure_text_fsdp_model(self): + # If FSDP is not enabled, put text_encoder to GPU and return. + if not self.config.text.fsdp.enabled: + self.text_encoder.to(get_device()) + return + + # from transformers.models.t5.modeling_t5 import T5Block + from humo.models.wan_modules.t5 import T5SelfAttention + + text_blocks = (torch.nn.Embedding, T5SelfAttention) + # text_blocks_names = ("QWenBlock", "QWenModel") # QWen cannot be imported. Use str. + + def custom_auto_wrap_policy(module, recurse, *args, **kwargs): + return ( + recurse + or isinstance(module, text_blocks) + ) + + # Apply FSDP. + text_encoder_dtype = getattr(torch, self.config.text.dtype) + device_mesh, fsdp_strategy = self.get_fsdp_sharding_config( + self.config.text.fsdp.sharding_strategy, + self.config.text.fsdp.get("device_mesh", None), + ) + self.text_encoder = FullyShardedDataParallel( + module=self.text_encoder, + auto_wrap_policy=custom_auto_wrap_policy, + sharding_strategy=fsdp_strategy, + backward_prefetch=BackwardPrefetch.BACKWARD_PRE, + device_id=get_local_rank(), + use_orig_params=False, + sync_module_states=False, + forward_prefetch=True, + limit_all_gathers=True, + mixed_precision=MixedPrecision( + param_dtype=text_encoder_dtype, + reduce_dtype=text_encoder_dtype, + buffer_dtype=text_encoder_dtype, + ), + device_mesh=device_mesh, + ) + self.text_encoder.to(get_device()).requires_grad_(False) + + + def load_image_latent_ref_id(self, path: str, size, device): + # Load size. + h, w = size[1], size[0] + + # Load image. + if len(path) > 1 and not isinstance(path, str): + ref_vae_latents = [] + for image_path in path: + with Image.open(image_path) as img: + img = img.convert("RGB") + + # Calculate the required size to keep aspect ratio and fill the rest with padding. + img_ratio = img.width / img.height + target_ratio = w / h + + if img_ratio > target_ratio: # Image is wider than target + new_width = w + new_height = int(new_width / img_ratio) + else: # Image is taller than target + new_height = h + new_width = int(new_height * img_ratio) + + # img = img.resize((new_width, new_height), Image.ANTIALIAS) + img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) + + # Create a new image with the target size and place the resized image in the center + delta_w = w - img.size[0] + delta_h = h - img.size[1] + padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2)) + new_img = ImageOps.expand(img, padding, fill=(255, 255, 255)) + + # Transform to tensor and normalize. 
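+                # ToTensor maps the PIL image from [0, 255] to [0, 1]; Normalize(0.5, 0.5)
+                # then rescales it to [-1, 1], the range the Wan VAE encoder expects.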
+ transform = Compose( + [ + ToTensor(), + Normalize(0.5, 0.5), + ] + ) + new_img = transform(new_img) + # img_vae_latent = self.vae_encode([new_img.unsqueeze(1)])[0] + img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device) + ref_vae_latents.append(img_vae_latent[0]) + + return [torch.cat(ref_vae_latents, dim=1)] + else: + if not isinstance(path, str): + path = path[0] + with Image.open(path) as img: + img = img.convert("RGB") + + # Calculate the required size to keep aspect ratio and fill the rest with padding. + img_ratio = img.width / img.height + target_ratio = w / h + + if img_ratio > target_ratio: # Image is wider than target + new_width = w + new_height = int(new_width / img_ratio) + else: # Image is taller than target + new_height = h + new_width = int(new_height * img_ratio) + + # img = img.resize((new_width, new_height), Image.ANTIALIAS) + img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) + + # Create a new image with the target size and place the resized image in the center + delta_w = w - img.size[0] + delta_h = h - img.size[1] + padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2)) + new_img = ImageOps.expand(img, padding, fill=(255, 255, 255)) + + # Transform to tensor and normalize. + transform = Compose( + [ + ToTensor(), + Normalize(0.5, 0.5), + ] + ) + new_img = transform(new_img) + img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device) + + # Vae encode. + return img_vae_latent + + def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2): + zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) + zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) # device=audio_emb.device + iter_ = 1 + (frame_num - 1) // 4 + audio_emb_wind = [] + for lt_i in range(iter_): + if lt_i == 0: + st = frame0_idx + lt_i - 2 + ed = frame0_idx + lt_i + 3 + wind_feat = torch.stack([ + audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed + for i in range(st, ed) + ], dim=0) + wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0) + else: + st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift + ed = frame0_idx + 1 + 4 * lt_i + audio_shift + wind_feat = torch.stack([ + audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed + for i in range(st, ed) + ], dim=0) + audio_emb_wind.append(wind_feat) + audio_emb_wind = torch.stack(audio_emb_wind, dim=0) + + return audio_emb_wind, ed - audio_shift + + def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"): + if wav_enc_type == "wav2vec": + feat_merge = audio_emb + elif wav_enc_type == "whisper": + feat0 = linear_interpolation_fps(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25) + feat1 = linear_interpolation_fps(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25) + feat2 = linear_interpolation_fps(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25) + feat3 = linear_interpolation_fps(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25) + feat4 = linear_interpolation_fps(audio_emb[:, :, 32], 50, 25) + feat_merge = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0] + else: + raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}") + + return feat_merge + + def parse_output(self, output): + latent = output[0] + mask = None + return latent, mask + + def forward_tia(self, latents, timestep, t, step_change, arg_tia, arg_ti, arg_i, arg_null): + pos_tia, _ = self.parse_output(self.dit( + latents, t=timestep, 
**arg_tia + )) + torch.cuda.empty_cache() + + pos_ti, _ = self.parse_output(self.dit( + latents, t=timestep, **arg_ti + )) + torch.cuda.empty_cache() + + if t > step_change: + neg, _ = self.parse_output(self.dit( + latents, t=timestep, **arg_i + )) # img included in null, same with official Wan-2.1 + torch.cuda.empty_cache() + + noise_pred = self.config.generation.scale_a * (pos_tia - pos_ti) + \ + self.config.generation.scale_t * (pos_ti - neg) + \ + neg + else: + neg, _ = self.parse_output(self.dit( + latents, t=timestep, **arg_null + )) # img not included in null + torch.cuda.empty_cache() + + noise_pred = self.config.generation.scale_a * (pos_tia - pos_ti) + \ + (self.config.generation.scale_t - 2.0) * (pos_ti - neg) + \ + neg + return noise_pred + + def forward_ti(self, latents, timestep, t, step_change, arg_ti, arg_t, arg_i, arg_null): + # Positive with text+image (no audio) + pos_ti, _ = self.parse_output(self.dit( + latents, t=timestep, **arg_ti + )) + torch.cuda.empty_cache() + + # Positive with text only (no image, no audio) + pos_t, _ = self.parse_output(self.dit( + latents, t=timestep, **arg_t + )) + torch.cuda.empty_cache() + + # Negative branch: before step_change, don't include image in null; after, include image (like Wan-2.1) + if t > step_change: + neg, _ = self.parse_output(self.dit( + latents, t=timestep, **arg_i + )) # img included in null + else: + neg, _ = self.parse_output(self.dit( + latents, t=timestep, **arg_null + )) # img NOT included in null + torch.cuda.empty_cache() + + # Guidance blend: replace "scale_a" below with "scale_i" if you add a separate image scale in config + noise_pred = self.config.generation.scale_a * (pos_ti - pos_t) + \ + self.config.generation.scale_t * (pos_t - neg) + \ + neg + return noise_pred + + def forward_ta(self, latents, timestep, arg_ta, arg_t, arg_null): + pos_ta, _ = self.parse_output(self.dit( + latents, t=timestep, **arg_ta + )) + torch.cuda.empty_cache() + + pos_t, _ = self.parse_output(self.dit( + latents, t=timestep, **arg_t + )) + torch.cuda.empty_cache() + + neg, _ = self.parse_output(self.dit( + latents, t=timestep, **arg_null + )) + torch.cuda.empty_cache() + + noise_pred = self.config.generation.scale_a * (pos_ta - pos_t) + \ + self.config.generation.scale_t * (pos_t - neg) + \ + neg + return noise_pred + + @torch.no_grad() + def inference(self, + input_prompt, + img_path, + audio_path, + size=(1280, 720), + frame_num=81, + shift=5.0, + sample_solver='unipc', + inference_mode='TIA', + sampling_steps=50, + n_prompt="", + seed=-1, + tea_cache_l1_thresh = 0.0, + device = get_device(), + ): + + print("inference started") + + # self.vae.model.to(device=device) + if img_path is not None: + latents_ref = self.load_image_latent_ref_id(img_path, size, device) + else: + latents_ref = [torch.zeros(16, 1, size[1]//8, size[0]//8).to(device)] + + # self.vae.model.to(device="cpu") + + print("vae finished") + + latents_ref_neg = [torch.zeros_like(latent_ref) for latent_ref in latents_ref] + + # audio + if audio_path is not None: + if self.config.generation.extract_audio_feat: + self.audio_processor.whisper.to(device=device) + audio_emb, audio_length = self.audio_processor.preprocess(audio_path) + self.audio_processor.whisper.to(device='cpu') + else: + audio_emb_path = audio_path.replace(".wav", ".pt") + audio_emb = torch.load(audio_emb_path).to(device=device) + audio_emb = self.audio_emb_enc(audio_emb, wav_enc_type="whisper") + self.logger.info("使用预先提取好的音频特征: %s", audio_emb_path) + else: + audio_emb = torch.zeros(frame_num, 5, 
1280).to(device) + + frame_num = frame_num if frame_num != -1 else audio_length + frame_num = 4 * ((frame_num - 1) // 4) + 1 + audio_emb, _ = self.get_audio_emb_window(audio_emb, frame_num, frame0_idx=0) + zero_audio_pad = torch.zeros(latents_ref[0].shape[1], *audio_emb.shape[1:]).to(audio_emb.device) + audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0) + audio_emb = [audio_emb.to(device)] + audio_emb_neg = [torch.zeros_like(audio_emb[0])] + + # preprocess + self.patch_size = self.config.dit.model.patch_size + F = frame_num + target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + latents_ref[0].shape[1], + size[1] // self.vae_stride[1], + size[0] // self.vae_stride[2]) + + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (self.patch_size[1] * self.patch_size[2]) * + target_shape[1] / self.sp_size) * self.sp_size + + if n_prompt == "": + n_prompt = self.config.generation.sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=device) + seed_g.manual_seed(seed) + + # self.text_encoder.model.to(device) + context = self.text_encoder([input_prompt], device) + context_null = self.text_encoder([n_prompt], device) + # self.text_encoder.model.cpu() + + print("text encoder finished") + + noise = [ + torch.randn( + target_shape[0], + target_shape[1], # - latents_ref[0].shape[1], + target_shape[2], + target_shape[3], + dtype=torch.float32, + device=device, + generator=seed_g) + ] + + @contextmanager + def noop_no_sync(): + yield + + no_sync = getattr(self.dit, 'no_sync', noop_no_sync) + step_change = self.config.generation.step_change # 980 + + # evaluation mode + with make_fp8_ctx(True), torch.autocast('cuda', dtype=torch.bfloat16), torch.no_grad(), no_sync(): + + if sample_solver == 'unipc': + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=1000, + shift=1, + use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=device, shift=shift) + timesteps = sample_scheduler.timesteps + + # sample videos + latents = noise + + msk = torch.ones(4, target_shape[1], target_shape[2], target_shape[3], device=get_device()) + msk[:,:-latents_ref[0].shape[1]] = 0 + + zero_vae = self.zero_vae[:, :(target_shape[1]-latents_ref[0].shape[1])].to( + device=get_device(), dtype=latents_ref[0].dtype) + y_c = torch.cat([ + zero_vae, + latents_ref[0] + ], dim=1) + y_c = [torch.concat([msk, y_c])] + + y_null = self.zero_vae[:, :target_shape[1]].to( + device=get_device(), dtype=latents_ref[0].dtype) + y_null = [torch.concat([msk, y_null])] + + tea_cache_l1_thresh = tea_cache_l1_thresh + tea_cache_model_id = "Wan2.1-T2V-14B" + + arg_null = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_null, 'context': context_null, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None} + arg_t = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_null, 'context': context, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None} + arg_i = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_c, 'context': context_null, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None} + arg_ti = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_c, 
'context': context, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None} + arg_ta = {'seq_len': seq_len, 'audio': audio_emb, 'y': y_null, 'context': context, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None} + arg_tia = {'seq_len': seq_len, 'audio': audio_emb, 'y': y_c, 'context': context, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None} + + torch.cuda.empty_cache() + # self.dit.to(device=get_device()) + for _, t in enumerate(tqdm(timesteps)): + timestep = [t] + timestep = torch.stack(timestep) + + if inference_mode == "TIA": + noise_pred = self.forward_tia(latents, timestep, t, step_change, + arg_tia, arg_ti, arg_i, arg_null) + elif inference_mode == "TA": + noise_pred = self.forward_ta(latents, timestep, arg_ta, arg_t, arg_null) + elif inference_mode == "TI": + noise_pred = self.forward_ti(latents, timestep, t, step_change, + arg_ti, arg_t, arg_i, arg_null) + else: + raise ValueError(f"Unsupported generation mode: {self.config.generation.mode}") + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), + t, + latents[0].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latents = [temp_x0.squeeze(0)] + + del timestep + torch.cuda.empty_cache() + + x0 = latents + x0 = [x0_[:,:-latents_ref[0].shape[1]] for x0_ in x0] + + # if offload_model: + # self.dit.cpu() + + print("dit finished") + + torch.cuda.empty_cache() + # if get_local_rank() == 0: + # self.vae.model.to(device=device) + videos = self.vae.decode(x0) + # self.vae.model.to(device="cpu") + + print("vae 2 finished") + + del noise, latents, noise_pred + del audio_emb, audio_emb_neg, latents_ref, latents_ref_neg, context, context_null + del x0, temp_x0 + del sample_scheduler + torch.cuda.empty_cache() + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos[0] # if get_local_rank() == 0 else None + + + def inference_loop(self, prompt, ref_img_path, audio_path, output_dir, filename, inference_mode = "TIA", width = 832, height = 480, steps=50, frames = 97, tea_cache_l1_thresh = 0.0, seed = 0): + + video = self.inference( + prompt, + ref_img_path, + audio_path, + size=SIZE_CONFIGS[f"{width}*{height}"], + frame_num=frames, + shift=self.config.diffusion.timesteps.sampling.shift, + sample_solver='unipc', + sampling_steps=steps, + inference_mode = inference_mode, + tea_cache_l1_thresh = tea_cache_l1_thresh, + seed=seed + ) + + torch.cuda.empty_cache() + gc.collect() + + # Save samples. + if get_sequence_parallel_rank() == 0: + pathname = self.save_sample( + sample=video, + audio_path=audio_path, + output_dir = output_dir, + filename=filename, + ) + self.logger.info(f"Finished {filename}, saved to {pathname}.") + + del video, prompt + torch.cuda.empty_cache() + gc.collect() + + + def save_sample(self, *, sample: torch.Tensor, audio_path: str, output_dir: str, filename: str): + gen_config = self.config.generation + # Prepare file path. + extension = ".mp4" if sample.ndim == 4 else ".png" + filename += extension + pathname = os.path.join(output_dir, filename) + # Convert sample. 
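+        # [-1, 1] float video -> uint8 in [0, 255], then "c t h w" -> "t h w c"
+        # (frame-major layout expected by the video writers below).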
+ sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).to("cpu", torch.uint8) + sample = rearrange(sample, "c t h w -> t h w c") + # Save file. + if sample.ndim == 4: + if audio_path is not None: + tensor_to_video( + sample.numpy(), + pathname, + audio_path, + fps=gen_config.fps) + else: + mediapy.write_video( + path=pathname, + images=sample.numpy(), + fps=gen_config.fps, + ) + else: + raise ValueError + return pathname + + + def prepare_positive_prompts(self): + pos_prompts = self.config.generation.positive_prompt + if pos_prompts.endswith(".json"): + pos_prompts = prepare_json_dataset(pos_prompts) + else: + raise NotImplementedError + assert isinstance(pos_prompts, ListConfig) + + return pos_prompts + +class TeaCache: + def __init__(self, num_inference_steps, rel_l1_thresh, model_id): + self.num_inference_steps = num_inference_steps + self.step = 0 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.rel_l1_thresh = rel_l1_thresh + self.previous_residual = None + self.previous_hidden_states = None + + self.coefficients_dict = { + "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], + "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], + "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01], + "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], + } + if model_id not in self.coefficients_dict: + supported_model_ids = ", ".join([i for i in self.coefficients_dict]) + raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).") + self.coefficients = self.coefficients_dict[model_id] + + def check(self, dit, x, t_mod): + modulated_inp = t_mod.clone() + if self.step == 0 or self.step == self.num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = self.coefficients + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.step += 1 + if self.step == self.num_inference_steps: + self.step = 0 + if should_calc: + self.previous_hidden_states = x.clone() + return not should_calc + + def store(self, hidden_states): + if self.previous_hidden_states is None: + return + self.previous_residual = hidden_states - self.previous_hidden_states + self.previous_hidden_states = None + + def update(self, hidden_states): + hidden_states = hidden_states + self.previous_residual + return hidden_states \ No newline at end of file diff --git a/humo/generate_1_7B.py b/humo/generate_1_7B.py new file mode 100644 index 0000000000000000000000000000000000000000..debd2cad292734d94ba30deb94725c2be7df302e --- /dev/null +++ b/humo/generate_1_7B.py @@ -0,0 +1,622 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Inference codes adapted from [SeedVR] +# https://github.com/ByteDance-Seed/SeedVR/blob/main/projects/inference_seedvr2_7b.py + +import math +import os +import gc +import random +import sys +import mediapy +import torch +import torch.distributed as dist +from omegaconf import DictConfig, ListConfig, OmegaConf +from einops import rearrange +from omegaconf import OmegaConf +from PIL import Image, ImageOps +from torchvision.transforms import ToTensor +from tqdm import tqdm +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import ( + BackwardPrefetch, + FullyShardedDataParallel, + MixedPrecision, + ShardingStrategy, +) +from common.distributed import ( + get_device, + get_global_rank, + get_local_rank, + meta_param_init_fn, + meta_non_persistent_buffer_init_fn, + init_torch, +) +from common.distributed.advanced import ( + init_unified_parallel, + get_unified_parallel_world_size, + get_sequence_parallel_rank, + init_model_shard_cpu_group, +) +from common.logger import get_logger +from common.config import create_object +from common.distributed import get_device, get_global_rank +from torchvision.transforms import Compose, Normalize, ToTensor +from humo.models.wan_modules.t5 import T5EncoderModel +from humo.models.wan_modules.vae import WanVAE +from humo.models.utils.utils import tensor_to_video, prepare_json_dataset +from contextlib import contextmanager +import torch.cuda.amp as amp +from humo.models.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from humo.utils.audio_processor_whisper import AudioProcessor +from humo.utils.wav2vec import linear_interpolation_fps + + +image_transform = Compose([ + ToTensor(), + Normalize(mean=0.5, std=0.5), +]) + +SIZE_CONFIGS = { + '720*1280': (720, 1280), + '1280*720': (1280, 720), + '480*832': (480, 832), + '832*480': (832, 480), + '1024*1024': (1024, 1024), +} + +def clever_format(nums, format="%.2f"): + from typing import Iterable + if not isinstance(nums, Iterable): + nums = [nums] + clever_nums = [] + for num in nums: + if num > 1e12: + clever_nums.append(format % (num / 1e12) + "T") + elif num > 1e9: + clever_nums.append(format % (num / 1e9) + "G") + elif num > 1e6: + clever_nums.append(format % (num / 1e6) + "M") + elif num > 1e3: + clever_nums.append(format % (num / 1e3) + "K") + else: + clever_nums.append(format % num + "B") + + clever_nums = clever_nums[0] if len(clever_nums) == 1 else (*clever_nums,) + + return clever_nums + + +class Generator(): + def __init__(self, config: DictConfig): + self.config = config.copy() + OmegaConf.set_readonly(self.config, True) + self.logger = get_logger(self.__class__.__name__) + self.configure_models() + + # init_torch(cudnn_benchmark=False) + + def get_fsdp_sharding_config(self, sharding_strategy, device_mesh_config): + device_mesh = None + fsdp_strategy = ShardingStrategy[sharding_strategy] + if ( + fsdp_strategy in [ShardingStrategy._HYBRID_SHARD_ZERO2, ShardingStrategy.HYBRID_SHARD] + and device_mesh_config is not None + ): + device_mesh = init_device_mesh("cuda", tuple(device_mesh_config)) + return device_mesh, fsdp_strategy + + def 
configure_models(self): + self.configure_dit_model(device="cpu") + self.configure_vae_model() + if self.config.generation.get('extract_audio_feat', False): + self.configure_wav2vec(device="cpu") + self.configure_text_model(device="cpu") + + # Initialize fsdp. + self.configure_dit_fsdp_model() + self.configure_text_fsdp_model() + + def configure_dit_model(self, device=get_device()): + + init_unified_parallel(self.config.dit.sp_size) + self.sp_size = get_unified_parallel_world_size() + + # Create dit model. + init_device = "meta" + with torch.device(init_device): + self.dit = create_object(self.config.dit.model) + self.logger.info(f"Load DiT model on {init_device}.") + self.dit.eval().requires_grad_(False) + + # Load dit checkpoint. + path = self.config.dit.checkpoint_dir + if path.endswith(".pth"): + state = torch.load(path, map_location=device, mmap=True) + missing_keys, unexpected_keys = self.dit.load_state_dict(state, strict=False, assign=True) + self.logger.info( + f"dit loaded from {path}. " + f"Missing keys: {len(missing_keys)}, " + f"Unexpected keys: {len(unexpected_keys)}" + ) + else: + from safetensors.torch import load_file + import json + def load_custom_sharded_weights(model_dir, base_name, device=device): + index_path = f"{model_dir}/{base_name}.safetensors.index.json" + with open(index_path, "r") as f: + index = json.load(f) + weight_map = index["weight_map"] + shard_files = set(weight_map.values()) + state_dict = {} + for shard_file in shard_files: + shard_path = f"{model_dir}/{shard_file}" + shard_state = load_file(shard_path) + shard_state = {k: v.to(device) for k, v in shard_state.items()} + state_dict.update(shard_state) + return state_dict + state = load_custom_sharded_weights(path, 'humo', device) + self.dit.load_state_dict(state, strict=False, assign=True) + + self.dit = meta_non_persistent_buffer_init_fn(self.dit) + if device in [get_device(), "cuda"]: + self.dit.to(get_device()) + + # Print model size. + params = sum(p.numel() for p in self.dit.parameters()) + self.logger.info( + f"[RANK:{get_global_rank()}] DiT Parameters: {clever_format(params, '%.3f')}" + ) + + def configure_vae_model(self, device=get_device()): + self.vae_stride = self.config.vae.vae_stride + self.vae = WanVAE( + vae_pth=self.config.vae.checkpoint, + device=device) + + if self.config.generation.height == 480: + self.zero_vae = torch.load(self.config.dit.zero_vae_path) + elif self.config.generation.height == 720: + self.zero_vae = torch.load(self.config.dit.zero_vae_720p_path) + else: + raise ValueError(f"Unsupported height {self.config.generation.height} for zero-vae.") + + def configure_wav2vec(self, device=get_device()): + audio_separator_model_file = self.config.audio.vocal_separator + wav2vec_model_path = self.config.audio.wav2vec_model + + self.audio_processor = AudioProcessor( + 16000, + 25, + wav2vec_model_path, + "all", + audio_separator_model_file, + None, # not seperate + os.path.join(self.config.generation.output.dir, "vocals"), + device=device, + ) + + def configure_text_model(self, device=get_device()): + self.text_encoder = T5EncoderModel( + text_len=self.config.dit.model.text_len, + dtype=torch.bfloat16, + device=device, + checkpoint_path=self.config.text.t5_checkpoint, + tokenizer_path=self.config.text.t5_tokenizer, + ) + + + def configure_dit_fsdp_model(self): + self.dit.to(get_device()) + + return + + + def configure_text_fsdp_model(self): + self.text_encoder.to(get_device()) + + return + + + def load_image_latent_ref_id(self, path: str, size, device): + # Load size. 
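+        # size comes from SIZE_CONFIGS as (width, height), hence h, w = size[1], size[0].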
+ h, w = size[1], size[0] + + # Load image. + if len(path) > 1 and not isinstance(path, str): + ref_vae_latents = [] + for image_path in path: + with Image.open(image_path) as img: + img = img.convert("RGB") + + # Calculate the required size to keep aspect ratio and fill the rest with padding. + img_ratio = img.width / img.height + target_ratio = w / h + + if img_ratio > target_ratio: # Image is wider than target + new_width = w + new_height = int(new_width / img_ratio) + else: # Image is taller than target + new_height = h + new_width = int(new_height * img_ratio) + + # img = img.resize((new_width, new_height), Image.ANTIALIAS) + img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) + + # Create a new image with the target size and place the resized image in the center + delta_w = w - img.size[0] + delta_h = h - img.size[1] + padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2)) + new_img = ImageOps.expand(img, padding, fill=(255, 255, 255)) + + # Transform to tensor and normalize. + transform = Compose( + [ + ToTensor(), + Normalize(0.5, 0.5), + ] + ) + new_img = transform(new_img) + # img_vae_latent = self.vae_encode([new_img.unsqueeze(1)])[0] + img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device) + ref_vae_latents.append(img_vae_latent[0]) + + return [torch.cat(ref_vae_latents, dim=1)] + else: + if not isinstance(path, str): + path = path[0] + with Image.open(path) as img: + img = img.convert("RGB") + + # Calculate the required size to keep aspect ratio and fill the rest with padding. + img_ratio = img.width / img.height + target_ratio = w / h + + if img_ratio > target_ratio: # Image is wider than target + new_width = w + new_height = int(new_width / img_ratio) + else: # Image is taller than target + new_height = h + new_width = int(new_height * img_ratio) + + # img = img.resize((new_width, new_height), Image.ANTIALIAS) + img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) + + # Create a new image with the target size and place the resized image in the center + delta_w = w - img.size[0] + delta_h = h - img.size[1] + padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2)) + new_img = ImageOps.expand(img, padding, fill=(255, 255, 255)) + + # Transform to tensor and normalize. + transform = Compose( + [ + ToTensor(), + Normalize(0.5, 0.5), + ] + ) + new_img = transform(new_img) + img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device) + + # Vae encode. 
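+        # With vae_stride [4, 8, 8], a single reference frame encodes to a
+        # [z_dim, 1, h/8, w/8] latent (z_dim = 16 for the Wan VAE).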
+ return img_vae_latent + + def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2): + zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) + zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) # device=audio_emb.device + iter_ = 1 + (frame_num - 1) // 4 + audio_emb_wind = [] + for lt_i in range(iter_): + if lt_i == 0: + st = frame0_idx + lt_i - 2 + ed = frame0_idx + lt_i + 3 + wind_feat = torch.stack([ + audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed + for i in range(st, ed) + ], dim=0) + wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0) + else: + st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift + ed = frame0_idx + 1 + 4 * lt_i + audio_shift + wind_feat = torch.stack([ + audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed + for i in range(st, ed) + ], dim=0) + audio_emb_wind.append(wind_feat) + audio_emb_wind = torch.stack(audio_emb_wind, dim=0) + + return audio_emb_wind, ed - audio_shift + + def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"): + if wav_enc_type == "wav2vec": + feat_merge = audio_emb + elif wav_enc_type == "whisper": + feat0 = linear_interpolation_fps(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25) + feat1 = linear_interpolation_fps(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25) + feat2 = linear_interpolation_fps(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25) + feat3 = linear_interpolation_fps(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25) + feat4 = linear_interpolation_fps(audio_emb[:, :, 32], 50, 25) + feat_merge = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0] + else: + raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}") + + return feat_merge + + def forward_tia(self, latents, latents_ref, latents_ref_neg, timestep, arg_t, arg_ta, arg_null): + neg = self.dit( + [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_null + )[0] + + pos_t = self.dit( + [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_t + )[0] + pos_ta = self.dit( + [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_ta + )[0] + pos_tia = self.dit( + [torch.cat([latent[:,:-latent_ref.shape[1]], latent_ref], dim=1) for latent, latent_ref in zip(latents, latents_ref)], t=timestep, **arg_ta + )[0] + + noise_pred = self.config.generation.scale_i * (pos_tia - pos_ta) + \ + self.config.generation.scale_a * (pos_ta - pos_t) + \ + self.config.generation.scale_t * (pos_t - neg) + \ + neg + + return noise_pred + + def forward_ta(self, latents, latents_ref_neg, timestep, arg_t, arg_ta, arg_null): + neg = self.dit( + [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_null + )[0] + + pos_t = self.dit( + [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_t + )[0] + pos_ta = self.dit( + [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_ta + )[0] + + noise_pred = self.config.generation.scale_a 
* (pos_ta - pos_t) + \ + self.config.generation.scale_t * (pos_t - neg) + \ + neg + + return noise_pred + + + @torch.no_grad() + def inference(self, + input_prompt, + img_path, + audio_path, + size=(1280, 720), + frame_num=81, + shift=5.0, + sample_solver='unipc', + sampling_steps=50, + n_prompt="", + seed=-1, + offload_model=True, + device = get_device(), + ): + + self.vae.model.to(device=device) + if img_path is not None: + latents_ref = self.load_image_latent_ref_id(img_path, size, device) + else: + latents_ref = [torch.zeros(16, 1, size[1]//8, size[0]//8).to(device)] + + self.vae.model.to(device="cpu") + latents_ref_neg = [torch.zeros_like(latent_ref) for latent_ref in latents_ref] + + # audio + if audio_path is not None: + if self.config.generation.extract_audio_feat: + self.audio_processor.whisper.to(device=device) + audio_emb, audio_length = self.audio_processor.preprocess(audio_path) + self.audio_processor.whisper.to(device='cpu') + else: + audio_emb_path = audio_path.replace(".wav", ".pt") + audio_emb = torch.load(audio_emb_path).to(device=device) + audio_emb = self.audio_emb_enc(audio_emb, wav_enc_type="whisper") + self.logger.info("使用预先提取好的音频特征: %s", audio_emb_path) + else: + audio_emb = torch.zeros(frame_num, 5, 1280).to(device) + + frame_num = frame_num if frame_num != -1 else audio_length + frame_num = 4 * ((frame_num - 1) // 4) + 1 + audio_emb, _ = self.get_audio_emb_window(audio_emb, frame_num, frame0_idx=0) + zero_audio_pad = torch.zeros(latents_ref[0].shape[1], *audio_emb.shape[1:]).to(audio_emb.device) + audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0) + audio_emb = [audio_emb.to(device)] + audio_emb_neg = [torch.zeros_like(audio_emb[0])] + + # preprocess + self.patch_size = self.config.dit.model.patch_size + F = frame_num + target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + latents_ref[0].shape[1], + size[1] // self.vae_stride[1], + size[0] // self.vae_stride[2]) + + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (self.patch_size[1] * self.patch_size[2]) * + target_shape[1] / self.sp_size) * self.sp_size + + if n_prompt == "": + n_prompt = self.config.generation.sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=device) + seed_g.manual_seed(seed) + + self.text_encoder.model.to(device) + context = self.text_encoder([input_prompt], device) + context_null = self.text_encoder([n_prompt], device) + self.text_encoder.model.cpu() + + noise = [ + torch.randn( + target_shape[0], + target_shape[1], # - latents_ref[0].shape[1], + target_shape[2], + target_shape[3], + dtype=torch.float32, + device=device, + generator=seed_g) + ] + + @contextmanager + def noop_no_sync(): + yield + + no_sync = getattr(self.dit, 'no_sync', noop_no_sync) + # step_change = self.config.generation.step_change # 980 + + # evaluation mode + with amp.autocast(dtype=torch.bfloat16), torch.no_grad(), no_sync(): + + if sample_solver == 'unipc': + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=1000, + shift=1, + use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=device, shift=shift) + timesteps = sample_scheduler.timesteps + + # sample videos + latents = noise + + # referene image在下面的输入中手动指定, 不在arg中指定 + arg_ta = {'context': context, 'seq_len': seq_len, 'audio': audio_emb} + arg_t = {'context': context, 'seq_len': seq_len, 'audio': audio_emb_neg} + arg_null = {'context': context_null, 'seq_len': seq_len, 'audio': audio_emb_neg} + + 
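+            # Guidance layout used below: forward_tia blends four forward passes as
+            #   scale_i * (pos_tia - pos_ta) + scale_a * (pos_ta - pos_t)
+            #   + scale_t * (pos_t - neg) + neg,
+            # and the reference latents are swapped into the tail of each latent by
+            # hand, which is why no 'y' key appears in these arg dicts.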
torch.cuda.empty_cache() + self.dit.to(device=get_device()) + for _, t in enumerate(tqdm(timesteps)): + timestep = [t] + timestep = torch.stack(timestep) + + if self.config.generation.mode == "TIA": + noise_pred = self.forward_tia(latents, latents_ref, latents_ref_neg, timestep, arg_t, arg_ta, arg_null) + elif self.config.generation.mode == "TA": + noise_pred = self.forward_ta(latents, latents_ref_neg, timestep, arg_t, arg_ta, arg_null) + else: + raise ValueError(f"Unsupported generation mode: {self.config.generation.mode}") + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), + t, + latents[0].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latents = [temp_x0.squeeze(0)] + + del timestep + torch.cuda.empty_cache() + + x0 = latents + x0 = [x0_[:,:-latents_ref[0].shape[1]] for x0_ in x0] + + # if offload_model: + self.dit.cpu() + torch.cuda.empty_cache() + # if get_local_rank() == 0: + self.vae.model.to(device=device) + videos = self.vae.decode(x0) + self.vae.model.to(device="cpu") + + del noise, latents, noise_pred + del audio_emb, audio_emb_neg, latents_ref, latents_ref_neg, context, context_null + del x0, temp_x0 + del sample_scheduler + torch.cuda.empty_cache() + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos[0] # if get_local_rank() == 0 else None + + + def inference_loop(self, prompt, ref_img_path, audio_path, output_dir, filename, width = 832, height = 480, steps=50, frames = 97, seed = 0): + print(f'ref_img_path:{ref_img_path}') + + video = self.inference( + prompt, + ref_img_path, + audio_path, + size=SIZE_CONFIGS[f"{width}*{height}"], + frame_num=frames, + shift=self.config.diffusion.timesteps.sampling.shift, + sample_solver='unipc', + sampling_steps=steps, + seed=seed, + offload_model=False, + ) + + torch.cuda.empty_cache() + gc.collect() + + + # Save samples. + if get_sequence_parallel_rank() == 0: + pathname = self.save_sample( + sample=video, + audio_path=audio_path, + output_dir = output_dir, + filename=filename, + ) + self.logger.info(f"Finished {filename}, saved to {pathname}.") + + del video, prompt + torch.cuda.empty_cache() + gc.collect() + + + + def save_sample(self, *, sample: torch.Tensor, audio_path: str, output_dir: str, filename: str): + gen_config = self.config.generation + # Prepare file path. + extension = ".mp4" if sample.ndim == 4 else ".png" + filename += extension + pathname = os.path.join(output_dir, filename) + # Convert sample. + sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).to("cpu", torch.uint8) + sample = rearrange(sample, "c t h w -> t h w c") + # Save file. 
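+        # tensor_to_video is expected to mux the driving audio track into the .mp4;
+        # when no audio path is given, mediapy writes a silent video instead.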
+ if sample.ndim == 4: + if audio_path is not None: + tensor_to_video( + sample.numpy(), + pathname, + audio_path, + fps=gen_config.fps) + else: + mediapy.write_video( + path=pathname, + images=sample.numpy(), + fps=gen_config.fps, + ) + else: + raise ValueError + return pathname + + + def prepare_positive_prompts(self): + pos_prompts = self.config.generation.positive_prompt + if pos_prompts.endswith(".json"): + pos_prompts = prepare_json_dataset(pos_prompts) + else: + raise NotImplementedError + assert isinstance(pos_prompts, ListConfig) + + return pos_prompts \ No newline at end of file diff --git a/humo/models/audio/audio_proj.py b/humo/models/audio/audio_proj.py new file mode 100644 index 0000000000000000000000000000000000000000..2d4771dad1e648fb063c30a18258d258a4739dc4 --- /dev/null +++ b/humo/models/audio/audio_proj.py @@ -0,0 +1,87 @@ +import torch +from einops import rearrange +from torch import nn +from einops import rearrange + +class WanRMSNorm(nn.Module): + + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return self._norm(x.float()).type_as(x) * self.weight + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class DummyAdapterLayer(nn.Module): + def __init__(self, layer): + super().__init__() + self.layer = layer + + def forward(self, *args, **kwargs): + return self.layer(*args, **kwargs) + + +class AudioProjModel(nn.Module): + def __init__( + self, + seq_len=5, + blocks=13, # add a new parameter blocks + channels=768, # add a new parameter channels + intermediate_dim=512, + output_dim=1536, + context_tokens=16, + ): + super().__init__() + + self.seq_len = seq_len + self.blocks = blocks + self.channels = channels + self.input_dim = seq_len * blocks * channels # update input_dim to be the product of blocks and channels. 
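+        # e.g. the defaults seq_len=5, blocks=13, channels=768 give
+        # input_dim = 5 * 13 * 768 = 49920 features per audio window.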
+ self.intermediate_dim = intermediate_dim + self.context_tokens = context_tokens + self.output_dim = output_dim + + # define multiple linear layers + self.audio_proj_glob_1 = DummyAdapterLayer(nn.Linear(self.input_dim, intermediate_dim)) + self.audio_proj_glob_2 = DummyAdapterLayer(nn.Linear(intermediate_dim, intermediate_dim)) + self.audio_proj_glob_3 = DummyAdapterLayer(nn.Linear(intermediate_dim, context_tokens * output_dim)) + + self.audio_proj_glob_norm = DummyAdapterLayer(nn.LayerNorm(output_dim)) + + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + def forward(self, audio_embeds): + video_length = audio_embeds.shape[1] + audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") + batch_size, window_size, blocks, channels = audio_embeds.shape + audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) + + audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds)) + audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds)) + + context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim) + + context_tokens = self.audio_proj_glob_norm(context_tokens) + context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length) + + return context_tokens \ No newline at end of file diff --git a/humo/models/distributed/__init__.py b/humo/models/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/humo/models/distributed/dit_ulysses_sequence_parallel.py b/humo/models/distributed/dit_ulysses_sequence_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..3da02c08b488bd5130e9879ed024b64ae09287d5 --- /dev/null +++ b/humo/models/distributed/dit_ulysses_sequence_parallel.py @@ -0,0 +1,270 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from einops import rearrange +from common.distributed import get_device + +from common.distributed.advanced import ( + get_unified_parallel_world_size, + get_unified_parallel_group, + pad_tensor, + Slice, + gather_outputs, + gather_seq_scatter_heads_qkv, + gather_seq_scatter_double_head, + gather_heads_scatter_seq, + unpad_tensor +) +from humo.models.wan_modules.attention import flash_attention +from humo.models.wan_modules.model_humo import rope_apply, sinusoidal_embedding_1d + + +def ulysses_dit_forward( + self, + x, + t, + context, + seq_len, + audio=None, + y=None +): + """ + x: A list of videos each with shape [C, T, H, W]. + t: [B]. + context: A list of text embeddings each with shape [L, C]. 
+ """ + if self.model_type == 'i2v': + # assert clip_fea is not None and y is not None + assert y is not None + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long, device=device) + + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) + for u in x + ]) + + # time embeddings + with torch.amp.autocast('cuda', dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()).float() + e0 = self.time_projection(e).unflatten(1, (6, self.dim)).float() + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + if self.insert_audio: + audio = [self.audio_proj(au.unsqueeze(0)).permute(0, 3, 1, 2) for au in audio] + + audio_seq_len = torch.tensor(max([au.shape[2] for au in audio]) * audio[0].shape[3], device=get_device()) + audio = [au.flatten(2).transpose(1, 2) for au in audio] # [1, t*32, 1536] + audio_seq_lens = torch.tensor([au.size(1) for au in audio], dtype=torch.long, device=device) + audio = torch.cat([ + torch.cat([au, au.new_zeros(1, audio_seq_len - au.size(1), au.size(2))], + dim=1) for au in audio + ]) + else: + audio = None + audio_seq_len = None + audio_seq_lens = None + + # ulysses support + sp_world = get_unified_parallel_world_size() + group = get_unified_parallel_group() + if seq_len % sp_world: + padding_size = sp_world - (seq_len % sp_world) + x = pad_tensor(x, dim=1, padding_size=padding_size) + + if self.insert_audio: + audio_padding_size = sp_world - (audio_seq_len % sp_world) + audio = pad_tensor(audio, dim=1, padding_size=audio_padding_size) + + x = Slice.apply(group, x, 1, True) + + if self.insert_audio: + audio = Slice.apply(group, audio, 1, True) + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + audio=audio, + audio_seq_len=audio_seq_len) + + for block in self.blocks: + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # ulysses support + x = gather_outputs(x, gather_dim=1, padding_dim=1, unpad_dim_size=seq_len, scale_grad=True) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] + + +def ulysses_attn_forward( + self, + x, + seq_lens, + grid_sizes, + freqs, + dtype=torch.bfloat16 +): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + seq_len = seq_lens.max() + half_dtypes = (torch.float16, torch.bfloat16) + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + return q, k, v + + q, k, v = qkv_fn(x) + + # ulysses support + sp_size = get_unified_parallel_world_size() + if n % sp_size: + pad_size = sp_size - (n % sp_size) + pad_size = pad_size * d + pad_inner_dim = n * d + pad_size + q = pad_tensor(q, dim=2, padding_size=pad_size) + k = pad_tensor(k, dim=2, 
padding_size=pad_size) + v = pad_tensor(v, dim=2, padding_size=pad_size) + else: + pad_inner_dim = n * d + + qkv = torch.cat([q, k, v], dim=2) + qkv = gather_seq_scatter_heads_qkv(qkv, seq_dim=1, unpadded_dim_size=seq_len) + q, k, v = qkv.split(pad_inner_dim // sp_size, dim=2) + + pad_n = pad_inner_dim // d + pad_split_n = pad_n // sp_size + q = q.view(b, seq_len, pad_split_n, d) + k = k.view(b, seq_len, pad_split_n, d) + v = v.view(b, seq_len, pad_split_n, d) + + q = rope_apply(q, grid_sizes, freqs) + k = rope_apply(k, grid_sizes, freqs) + + x = flash_attention( + q=half(q), + k=half(k), + v=half(v), + k_lens=seq_lens, + window_size=self.window_size + ) + + # ulysses support + x = x.flatten(2) + x = gather_heads_scatter_seq(x, head_dim=2, seq_dim=1) + if n % sp_size: + x = unpad_tensor(x, dim=2, unpad_dim_size=seq_len) + + x = self.o(x) + return x + + +def ulysses_audio_cross_attn_forward( + self, + x, + audio, + seq_lens, + grid_sizes, + freqs, + audio_seq_len, + dtype=torch.bfloat16 +): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + seq_len = seq_lens.max() + + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(audio)) + v = self.v(audio) + + # ulysses support + sp_size = get_unified_parallel_world_size() + if n % sp_size: + pad_size = sp_size - (n % sp_size) + pad_size = pad_size * d + pad_inner_dim = n * d + pad_size + q = pad_tensor(q, dim=2, padding_size=pad_size) + k = pad_tensor(k, dim=2, padding_size=pad_size) + v = pad_tensor(v, dim=2, padding_size=pad_size) + else: + pad_inner_dim = n * d + + qq = torch.cat([q, q], dim=2) + kv = torch.cat([k, v], dim=2) + qq = gather_seq_scatter_double_head(qq, seq_dim=1, unpadded_dim_size=seq_len) + kv = gather_seq_scatter_double_head(kv, seq_dim=1, unpadded_dim_size=audio_seq_len) + q, _ = qq.split(pad_inner_dim // sp_size, dim=2) + k, v = kv.split(pad_inner_dim // sp_size, dim=2) + + pad_n = pad_inner_dim // d + pad_split_n = pad_n // sp_size + q = q.view(b, seq_len, pad_split_n, d) + k = k.view(b, audio_seq_len, pad_split_n, d) + v = v.view(b, audio_seq_len, pad_split_n, d) + + hlen_wlen = int(grid_sizes[0][1] * grid_sizes[0][2]) + assert hlen_wlen == 1560 or hlen_wlen == 3600 + q = q.reshape(-1, hlen_wlen, pad_split_n, d) + k = k.reshape(-1, 16, pad_split_n, d) + v = v.reshape(-1, 16, pad_split_n, d) + + x = flash_attention( + q=q, + k=k, + v=v, + k_lens=None, + ) + x = x.view(b, -1, pad_split_n, d) + + # ulysses support + x = x.flatten(2) + x = gather_heads_scatter_seq(x, head_dim=2, seq_dim=1) + if n % sp_size: + x = unpad_tensor(x, dim=2, unpad_dim_size=seq_len) + + x = self.o(x) + return x diff --git a/humo/models/distributed/fsdp.py b/humo/models/distributed/fsdp.py new file mode 100644 index 0000000000000000000000000000000000000000..51c81e7700577e07ef1a2ebf7f26cc02ad269cd7 --- /dev/null +++ b/humo/models/distributed/fsdp.py @@ -0,0 +1,42 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
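+# Illustrative usage sketch for the `shard_model` helper defined below (annotation, not
+# part of the patch). The auto-wrap policy wraps exactly the modules in `model.blocks`,
+# so the wrapped module must expose such an iterable of transformer blocks;
+# `build_dit_model` is a hypothetical constructor used only for illustration.
+#
+#     import torch
+#     import torch.distributed as dist
+#
+#     dist.init_process_group("nccl")
+#     local_rank = dist.get_rank() % torch.cuda.device_count()
+#     torch.cuda.set_device(local_rank)
+#     dit = build_dit_model()                        # hypothetical; must expose .blocks
+#     dit = shard_model(dit, device_id=local_rank)   # FULL_SHARD, bf16 params, fp32 reduce/buffers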
+ +from functools import partial + +import torch +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy +from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy + + +def shard_model( + model, + device_id, + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, + process_group=None, + sharding_strategy=ShardingStrategy.FULL_SHARD, + sync_module_states=True, +): + model = FSDP( + module=model, + process_group=process_group, + sharding_strategy=sharding_strategy, + auto_wrap_policy=partial( + lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks), + mixed_precision=MixedPrecision( + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + buffer_dtype=buffer_dtype), + device_id=device_id, + sync_module_states=sync_module_states) + return model diff --git a/humo/models/text/encoder.py b/humo/models/text/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4b444bb7e778936be7f874947942780bcf3cfcfc --- /dev/null +++ b/humo/models/text/encoder.py @@ -0,0 +1,173 @@ +import os +from dataclasses import dataclass +from typing import List, Optional, Union +import torch +from omegaconf import DictConfig, OmegaConf +from torch import nn +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + CLIPTextModel, + CLIPTokenizerFast, + T5EncoderModel, + T5TokenizerFast, +) +from transformers.tokenization_utils_base import BatchEncoding + +from common.fs import download_and_extract +from common.logger import get_logger + +logger = get_logger(__name__) + +MODEL_TYPES = { + "clip": (CLIPTokenizerFast, CLIPTextModel), + "t5": (T5TokenizerFast, T5EncoderModel), + "llm14b": (AutoTokenizer, AutoModelForCausalLM), +} + + +@dataclass +class TextEncoderOutput: + embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]] + masks: Union[torch.BoolTensor, List[torch.BoolTensor]] + pooled: Optional[Union[torch.FloatTensor, List[torch.FloatTensor]]] + + +class TextEncoder(nn.Module): + def __init__(self, config: DictConfig): + super().__init__() + self.config = config + self.tokenizers = [] + self.models = nn.ModuleList([]) + + # Disable tokenizer parallelism since we already use distributed training. 
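+        # (Annotation, not part of the patch) Fast Hugging Face tokenizers parallelize in Rust
+        # worker threads; when the process later forks (e.g. for DataLoader workers under
+        # distributed training) this produces "the current process just got forked" warnings
+        # and can deadlock, so the flag below turns that parallelism off.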
+ os.environ["TOKENIZERS_PARALLELISM"] = "false" + + for model in config.models: + tokenizer_cls, model_cls = MODEL_TYPES[model.type] + path = download_and_extract(model.path) + max_length = model.max_length + + if model.type == "llm14b": + tokenizer = tokenizer_cls.from_pretrained( + path, + model_max_length=max_length, + use_fast=False, + trust_remote_code=True, + padding_side="right", + truncation_side="right", + add_eod_token=True, + ) + tokenizer.add_special_tokens({"pad_token": "<|endoftext|>"}) + model = model_cls.from_pretrained(path, trust_remote_code=True, bf16=True) + else: + tokenizer = tokenizer_cls.from_pretrained(path, model_max_length=max_length) + model = model_cls.from_pretrained(path, torch_dtype=torch.bfloat16) + self.tokenizers.append(tokenizer) + self.models.append(model) + + def forward(self, text: Union[str, List[str]]) -> TextEncoderOutput: + embeddings, masks, pooled = [], [], [] + + for encoder_config, tokenizer, model in zip( + self.config.models, self.tokenizers, self.models + ): + if encoder_config.type == "llm14b": + use_mask = encoder_config.get("mask", True) + tokens = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + ).to(model.device) + token_ids = tokens["input_ids"] + attention_mask = tokens["attention_mask"] + num_tokens = attention_mask.sum(dim=1) + range_ids = torch.arange(len(token_ids), device=token_ids.device, dtype=torch.long) + token_ids[range_ids, num_tokens.clamp(max=token_ids.size(1) - 1)] = ( + tokenizer.pad_token_id + ) + attention_mask[range_ids, num_tokens.clamp(max=token_ids.size(1) - 1)] = 1 + tokens = BatchEncoding({"input_ids": token_ids, "attention_mask": attention_mask}) + output = model.transformer( + input_ids=tokens.input_ids, + attention_mask=attention_mask if use_mask else None, + output_hidden_states=False, + use_cache=False, + ) + emb = output.last_hidden_state # batch_size, num_tokens, feat_dim + # emb *= tokens.attention_mask.unsqueeze(-1) + + embeddings.append(emb) + masks.append( + tokens.attention_mask.bool() if use_mask else tokens.attention_mask > -1 + ) + + else: + # Tokenizer + tokens = tokenizer( + text=text, + truncation=True, + padding="max_length", + return_tensors="pt", + ) + + # Encoder + use_mask = encoder_config.get("mask", True) + input_ids = tokens.input_ids.to(model.device) + attention_mask = tokens.attention_mask.to(model.device) + output = model( + input_ids=input_ids, + attention_mask=attention_mask if use_mask else None, + output_hidden_states=True, + ) + + # Save embeddings from the defined layer. + layer = encoder_config.get("layer", "last") + if layer == "last": + embeddings.append(output.last_hidden_state) + elif layer == "penultimate": + embeddings.append(model.text_model.final_layer_norm(output.hidden_states[-2])) + elif layer == "penultimate_nonorm": + embeddings.append(output.hidden_states[-2]) + else: + raise NotImplementedError(f"Unknown layer type: {layer}.") + + # Save masks + masks.append(attention_mask.bool() if use_mask else attention_mask > -1) + + # Save pooled output if available. + if hasattr(output, "pooler_output"): + pooled.append(output.pooler_output) + + output_config = self.config.get("output") or OmegaConf.create() + embedding_output_type = output_config.get("embedding_and_mask", "undefined") + pooled_output_type = output_config.get("pooled", "undefined") + + # Select or merge embeddings and mask if needed. 
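+        # Illustrative config sketch (annotation, not part of the patch; the key names follow the
+        # reads above, all values are hypothetical). The merge behaviour below is driven by an
+        # optional `output` section next to the `models` list, e.g.:
+        #
+        #     from omegaconf import OmegaConf
+        #
+        #     example_cfg = OmegaConf.create({
+        #         "models": [
+        #             {"type": "t5", "path": "<t5 checkpoint>", "max_length": 512,
+        #              "layer": "last", "mask": True},
+        #         ],
+        #         "output": {
+        #             "embedding_and_mask": "last",   # or "channel_concat"; omitted with one model -> passthrough
+        #             "pooled": "first",              # or "channel_concat" / "last"; omitted -> pooled is None
+        #         },
+        #     })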
+ if embedding_output_type == "undefined" and len(self.models) == 1: + embeddings = embeddings[0] + masks = masks[0] + elif embedding_output_type == "channel_concat": + embeddings = torch.cat(embeddings, dim=-1) + masks = sum(masks).bool() + elif embedding_output_type == "last": + embeddings = embeddings[-1] + masks = masks[-1] + else: + raise NotImplementedError(f"output.embedding_and_mask: {embedding_output_type}") + + # Select or merge pooled output if needed. + if pooled_output_type == "undefined": + pooled = None + elif pooled_output_type == "channel_concat": + pooled = torch.cat(pooled, dim=-1) + elif pooled_output_type == "first": + pooled = pooled[0] + elif pooled_output_type == "last": + pooled = pooled[-1] + else: + raise NotImplementedError(f"output.pooled: {pooled_output_type}") + + # Return final results. + return TextEncoderOutput(embeddings, masks, pooled) diff --git a/humo/models/utils/fm_solvers.py b/humo/models/utils/fm_solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..1271b678bd0b1287adebbc344ef5fbfa952a9558 --- /dev/null +++ b/humo/models/utils/fm_solvers.py @@ -0,0 +1,857 @@ +# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +# Convert dpm solver for flow matching +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +import inspect +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput) +from diffusers.utils import deprecate, is_scipy_available +from diffusers.utils.torch_utils import randn_tensor + +if is_scipy_available(): + pass + + +def get_sampling_sigmas(sampling_steps, shift): + sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps] + sigma = (shift * sigma / (1 + (shift - 1) * sigma)) + + return sigma + + +def retrieve_timesteps( + scheduler, + num_inference_steps=None, + device=None, + timesteps=None, + sigmas=None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. This determines the resolution of the diffusion process. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. This affects the number of model outputs stored + and used in multistep updates. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + shift (`float`, *optional*, defaults to 1.0): + A factor used to adjust the sigmas in the noise schedule. It modifies the step sizes during the sampling + process. + use_dynamic_shifting (`bool`, defaults to `False`): + Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is + applied on the fly. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This method adjusts the predicted sample to prevent + saturation and improve photorealism. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver++`): + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the + [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or + `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. 
+ final_sigmas_type (`str`, *optional*, defaults to "zero"): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + invert_sigmas: bool = False, + ): + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", + deprecation_message) + + # settings for DPM-Solver + if algorithm_type not in [ + "dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++" + ]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError( + f"{algorithm_type} is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}") + + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++" + ] and final_sigmas_type == "zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." + ) + + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, + num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + + # self.sigmas = self.sigmas.to( + # "cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. 
+ """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, + num_inference_steps + + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / + self.alphas_cumprod[0])**0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last] + ]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + self._step_index = None + self._begin_index = None + # self.sigmas = self.sigmas.to( + # "cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float( + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile( + abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze( + 1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp( + sample, -s, s + ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + "missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction`, or `flow_prediction` for the FlowDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. 
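+        # Worked derivation (annotation, not part of the patch) for the data-prediction branch
+        # above: with alpha_t = 1 - sigma_t (see _sigma_to_alpha_sigma_t) the noisy sample is
+        #     x_t = (1 - sigma_t) * x_0 + sigma_t * eps,
+        # and the network's flow prediction is v = eps - x_0, so
+        #     x_t - sigma_t * v = (1 - sigma_t) * x_0 + sigma_t * eps - sigma_t * eps + sigma_t * x_0 = x_0,
+        # which is exactly `x0_pred = sample - sigma_t * model_output`.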
+ elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the FlowDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[ + self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / + sigma_s) * sample - (alpha_t * + (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / + alpha_s) * sample - (sigma_t * + (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ((sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ((alpha_t / alpha_s) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + return x_t # pyright: ignore + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One 
step for the second-order multistep DPMSolver. + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop( + "timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], # pyright: ignore + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], # pyright: ignore + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * + (alpha_t * (torch.exp(-h) - 1.0)) * D1) + elif self.config.solver_type == "heun": + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ((alpha_t / alpha_s0) * sample - + (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 * + (sigma_t * (torch.exp(h) - 1.0)) * D1) + elif self.config.solver_type == "heun": + x_t = ((alpha_t / alpha_s0) * sample - + (sigma_t * (torch.exp(h) - 1.0)) * D0 - + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 * + (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.solver_type == "heun": + x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / + (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ((alpha_t / alpha_s0) * sample - 2.0 * + 
(sigma_t * (torch.exp(h) - 1.0)) * D0 - + (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + elif self.config.solver_type == "heun": + x_t = ((alpha_t / alpha_s0) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 * + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + return x_t # pyright: ignore + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the third-order multistep DPMSolver. + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop( + "timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], # pyright: ignore + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], # pyright: ignore + self.sigmas[self.step_index - 2], # pyright: ignore + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[ + -2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - + (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ((alpha_t / alpha_s0) * sample - (sigma_t * + (torch.exp(h) - 1.0)) * D0 - + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - + (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2) 
+ return x_t # pyright: ignore + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + # Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`LEdits++`]. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
+ """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final or + (self.config.lower_order_final and len(self.timesteps) < 15) or + self.config.final_sigmas_type == "zero") + lower_order_second = ((self.step_index == len(self.timesteps) - 2) and + self.config.lower_order_final and + len(self.timesteps) < 15) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++" + ] and variance_noise is None: + noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=torch.float32) + elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = variance_noise.to( + device=model_output.device, + dtype=torch.float32) # pyright: ignore + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update( + model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update( + self.model_outputs, sample=sample, noise=noise) + else: + prev_sample = self.multistep_dpm_solver_third_order_update( + self.model_outputs, sample=sample) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # Cast sample back to expected dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, *args, + **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + Args: + sample (`torch.Tensor`): + The input sample. + Returns: + `torch.Tensor`: + A scaled input sample. 
+ """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point( + timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32) + timesteps = timesteps.to( + original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) + for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/humo/models/utils/fm_solvers_unipc.py b/humo/models/utils/fm_solvers_unipc.py new file mode 100644 index 0000000000000000000000000000000000000000..1ed93733d74df502df2a9bf1dc6509c2193368b7 --- /dev/null +++ b/humo/models/utils/fm_solvers_unipc.py @@ -0,0 +1,800 @@ +# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py +# Convert unipc for flow matching +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput) +from diffusers.utils import deprecate, is_scipy_available + +if is_scipy_available(): + import scipy.stats + + +class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, default `2`): + The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` + due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for + unconditional sampling. 
+ prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, defaults to `True`): + Whether to use the updating algorithm on the predicted x0. + solver_type (`str`, default `bh2`): + Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` + and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is + usually disabled during the first few steps. + solver_p (`SchedulerMixin`, default `None`): + Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. 
+ """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: List[int] = [], + solver_p: SchedulerMixin = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, + num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
+ """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, + num_inference_steps + + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / + self.alphas_cumprod[0])**0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last] + ]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float( + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile( + abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze( + 1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp( + sample, -s, s + ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + r""" + Convert the model output to the corresponding type the UniPC algorithm needs. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + "missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." 
+ ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of UniP at this timestep (corresponds to the *p* in UniPC-p). + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + prev_timestep = args[0] if len(args) > 0 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError( + " missing `order` as a required keyward argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[ + self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], + b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = 
torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniC (B(h) version). + + Args: + this_model_output (`torch.Tensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`torch.Tensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`torch.Tensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. + + Returns: + `torch.Tensor`: + The corrected sample tensor at the current timestep. + """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop( + "this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError( + " missing`last_sample` as a required keyward argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError( + " missing`this_sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError( + " missing`order` as a required keyward argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[ + self.step_index - 1] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a 
simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step(self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + generator=None) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep UniPC. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
+ + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 and + self.step_index - 1 not in self.disable_corrector and + self.last_sample is not None # pyright: ignore + ) + + model_output_convert = self.convert_model_output( + model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep # pyright: ignore + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, + len(self.timesteps) - + self.step_index) # pyright: ignore + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, + self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, + **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. 
+ """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point( + timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32) + timesteps = timesteps.to( + original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) + for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/humo/models/utils/utils.py b/humo/models/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2b30b2e95c73637ae87cb73452173e0592dba34c --- /dev/null +++ b/humo/models/utils/utils.py @@ -0,0 +1,58 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import argparse +import binascii +import os +import os.path as osp +import json +from omegaconf import OmegaConf + +import imageio +import torch +import torchvision +from moviepy.editor import AudioFileClip, VideoClip + +__all__ = ['tensor_to_video', 'prepare_json_dataset'] + + +def tensor_to_video(tensor, output_video_path, input_audio_path, fps=25): + """ + Converts a Tensor with shape [c, f, h, w] into a video and adds an audio track from the specified audio file. + + Args: + tensor (numpy): The Tensor to be converted, shaped [f, h, w, c]. + output_video_path (str): The file path where the output video will be saved. + input_audio_path (str): The path to the audio file (WAV file) that contains the audio track to be added. + fps (int): The frame rate of the output video. Default is 30 fps. 
+ """ + def make_frame(t): + frame_index = min(int(t * fps), tensor.shape[0] - 1) + return tensor[frame_index] + + video_duration = tensor.shape[0] / fps + audio_clip = AudioFileClip(input_audio_path) + audio_duration = audio_clip.duration + final_duration = min(video_duration, audio_duration) + audio_clip = audio_clip.subclip(0, final_duration) + new_video_clip = VideoClip(make_frame, duration=final_duration) + new_video_clip = new_video_clip.set_audio(audio_clip) + new_video_clip.write_videofile(output_video_path, fps=fps, audio_codec="aac") + + +def prepare_json_dataset(json_path): + samples = [] + with open(json_path, "rb") as f: + data = json.load(f) + for itemname, row in data.items(): + text = row['prompt'].strip().replace("_", " ").strip('"') + audio_path = row['audio_path'] + ref_img_path = [x for x in row['img_paths']] + + samples.append({ + "text": text, + "ref_img": ref_img_path, + "audio": audio_path, + "itemname": itemname + }) + samples = OmegaConf.create(samples) + + return samples \ No newline at end of file diff --git a/humo/models/wan_modules/__init__.py b/humo/models/wan_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c7e79b0864945199247cabc9b48138e411b3216 --- /dev/null +++ b/humo/models/wan_modules/__init__.py @@ -0,0 +1,16 @@ +from .attention import flash_attention +from .model import WanModel +from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model +from .tokenizers import HuggingfaceTokenizer +from .vae import WanVAE + +__all__ = [ + 'WanVAE', + 'WanModel', + 'T5Model', + 'T5Encoder', + 'T5Decoder', + 'T5EncoderModel', + 'HuggingfaceTokenizer', + 'flash_attention', +] diff --git a/humo/models/wan_modules/attention.py b/humo/models/wan_modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..76d26535e1d7e106dc8889e0a64549ebc2aad600 --- /dev/null +++ b/humo/models/wan_modules/attention.py @@ -0,0 +1,256 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import warnings +import torch +from typing import Optional, Tuple + +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + +try: + import flash_attn + FLASH_ATTN_2_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False + + +__all__ = [ + 'flash_attention', + 'attention', +] + + +# --------------------------- +# Custom op + fake kernel +# --------------------------- +from typing import Optional, Sequence # <- add Sequence + +# ... imports unchanged ... 
+from typing import Optional, Sequence + +@torch.library.custom_op("wan::flash_attention", mutates_args=()) +def _wan_flash_attention_op( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q_lens: Optional[torch.Tensor] = None, + k_lens: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + q_scale: Optional[float] = None, + causal: bool = False, + # IMPORTANT: schema-friendly default (None), not a tuple + window_size: Optional[Sequence[int]] = None, + deterministic: bool = False, + dtype: torch.dtype = torch.bfloat16, + version: Optional[int] = None, +) -> torch.Tensor: + half_dtypes = (torch.float16, torch.bfloat16) + assert dtype in half_dtypes + assert q.size(-1) <= 256 + + # normalize window_size to a 2-tuple for FA2 API + if window_size is None: + ws = (-1, -1) + else: + ws = tuple(window_size) + if len(ws) != 2: + raise ValueError(f"window_size must have length 2; got {window_size!r}") + + b, lq, nheads = q.shape[0], q.shape[1], q.shape[2] + lk = k.shape[1] + out_dtype = q.dtype + + def half(x: torch.Tensor) -> torch.Tensor: + return x if x.dtype in half_dtypes else x.to(dtype) + + # --- preprocess (unchanged) --- + if q_lens is None: + q_flat = half(q.flatten(0, 1)) + q_lens = torch.tensor([lq] * b, dtype=torch.int32) + else: + q_flat = half(torch.cat([u[:v] for u, v in zip(q, q_lens)])) + + if k_lens is None: + k_flat = half(k.flatten(0, 1)) + v_flat = half(v.flatten(0, 1)) + k_lens = torch.tensor([lk] * b, dtype=torch.int32) + else: + k_flat = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) + v_flat = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) + + q_flat = q_flat.to(v_flat.dtype); k_flat = k_flat.to(v_flat.dtype) + if q_scale is not None: + q_flat = q_flat * q_scale + + if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: + warnings.warn('Flash attention 3 is not available, use flash attention 2 instead.') + + if FLASH_ATTN_3_AVAILABLE: + ret = flash_attn_interface.flash_attn_varlen_func( + q=q_flat, + k=k_flat, + v=v_flat, + cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q_flat.device, non_blocking=True), + cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(k_flat.device, non_blocking=True), + seqused_q=None, + seqused_k=None, + max_seqlen_q=lq, + max_seqlen_k=lk, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + ) + out0 = ret[0] if isinstance(ret, (tuple, list)) else ret + total_q = b * lq + if out0.dim() != 3: + raise RuntimeError(f"Unexpected FA3 output rank {out0.dim()} shape={tuple(out0.shape)}") + if out0.shape[0] == total_q: + out_flat = out0 + elif out0.shape[0] == nheads and out0.shape[1] == total_q: + out_flat = out0.transpose(0, 1).contiguous() + else: + raise RuntimeError(f"Unexpected FA3 output shape {tuple(out0.shape)}") + out = out_flat.unflatten(0, (b, lq)) + + elif FLASH_ATTN_2_AVAILABLE: + out = flash_attn.flash_attn_varlen_func( + q=q_flat, + k=k_flat, + v=v_flat, + cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q_flat.device, non_blocking=True), + cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q_flat.device, non_blocking=True), + max_seqlen_q=lq, + max_seqlen_k=lk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=ws, # <- pass 2-tuple + deterministic=deterministic, + ).unflatten(0, (b, lq)) + else: + q_s = q.transpose(1, 2).to(dtype) + k_s = 
k.transpose(1, 2).to(dtype) + v_s = v.transpose(1, 2).to(dtype) + out = torch.nn.functional.scaled_dot_product_attention( + q_s, k_s, v_s, attn_mask=None, is_causal=causal, dropout_p=dropout_p + ).transpose(1, 2).contiguous() + + return out.to(out_dtype) + +@_wan_flash_attention_op.register_fake +def _wan_flash_attention_op_fake( + q, + k, + v, + q_lens=None, + k_lens=None, + dropout_p: float = 0.0, + softmax_scale=None, + q_scale=None, + causal: bool = False, + window_size: Optional[Sequence[int]] = None, + deterministic: bool = False, + dtype: torch.dtype = torch.bfloat16, + version: Optional[int] = None, +): + # Match output shape: (B, Lq, Nq, Dh_v) and keep the SAME fake device as `q` + B, Lq, Nq, _ = q.shape + Dh_v = v.shape[-1] + return q.new_empty((B, Lq, Nq, Dh_v), dtype=q.dtype) + + + +# --------------------------- +# Public API (unchanged signature) +# --------------------------- +def flash_attention( + q, + k, + v, + q_lens=None, + k_lens=None, + dropout_p=0., + softmax_scale=None, + q_scale=None, + causal=False, + window_size=(-1, -1), + deterministic=False, + dtype=torch.bfloat16, + version=None, +): + """ + q: [B, Lq, Nq, C1]. + k: [B, Lk, Nk, C1]. + v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. + q_lens: [B]. + k_lens: [B]. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + causal: bool. Whether to apply causal attention mask. + window_size: (left right). If not (-1, -1), apply sliding window local attention. + deterministic: bool. If True, slightly slower and uses more memory. + dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16. + """ + # Simply delegate to the custom op so Dynamo/AOT treats it as a single node; + # our eager kernel inside _wan_flash_attention_op keeps the original behavior. + return _wan_flash_attention_op( + q, k, v, + q_lens=q_lens, + k_lens=k_lens, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + q_scale=q_scale, + causal=causal, + window_size=window_size, + deterministic=deterministic, + dtype=dtype, + version=version, + ) + + +def attention( + q, + k, + v, + q_lens=None, + k_lens=None, + dropout_p=0., + softmax_scale=None, + q_scale=None, + causal=False, + window_size=(-1, -1), + deterministic=False, + dtype=torch.bfloat16, + fa_version=None, +): + if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE: + return flash_attention( + q=q, + k=k, + v=v, + q_lens=q_lens, + k_lens=k_lens, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + q_scale=q_scale, + causal=causal, + window_size=window_size, + deterministic=deterministic, + dtype=dtype, + version=fa_version, + ) + else: + if q_lens is not None or k_lens is not None: + warnings.warn( + 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.' + ) + q_ = q.transpose(1, 2).to(dtype) + k_ = k.transpose(1, 2).to(dtype) + v_ = v.transpose(1, 2).to(dtype) + out = torch.nn.functional.scaled_dot_product_attention( + q_, k_, v_, attn_mask=None, is_causal=causal, dropout_p=dropout_p + ) + return out.transpose(1, 2).contiguous() diff --git a/humo/models/wan_modules/clip.py b/humo/models/wan_modules/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..8659c4f13ce4da36110edb49a0ac71ccd5e3f841 --- /dev/null +++ b/humo/models/wan_modules/clip.py @@ -0,0 +1,542 @@ +# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip'' +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+import logging +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as T + +from .attention import flash_attention +from .tokenizers import HuggingfaceTokenizer +from .xlm_roberta import XLMRoberta + +__all__ = [ + 'XLMRobertaCLIP', + 'clip_xlm_roberta_vit_h_14', + 'CLIPModel', +] + + +def pos_interpolate(pos, seq_len): + if pos.size(1) == seq_len: + return pos + else: + src_grid = int(math.sqrt(pos.size(1))) + tar_grid = int(math.sqrt(seq_len)) + n = pos.size(1) - src_grid * src_grid + return torch.cat([ + pos[:, :n], + F.interpolate( + pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute( + 0, 3, 1, 2), + size=(tar_grid, tar_grid), + mode='bicubic', + align_corners=False).flatten(2).transpose(1, 2) + ], + dim=1) + + +class QuickGELU(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + +class LayerNorm(nn.LayerNorm): + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class SelfAttention(nn.Module): + + def __init__(self, + dim, + num_heads, + causal=False, + attn_dropout=0.0, + proj_dropout=0.0): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.causal = causal + self.attn_dropout = attn_dropout + self.proj_dropout = proj_dropout + + # layers + self.to_qkv = nn.Linear(dim, dim * 3) + self.proj = nn.Linear(dim, dim) + + def forward(self, x): + """ + x: [B, L, C]. + """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2) + + # compute attention + p = self.attn_dropout if self.training else 0.0 + x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2) + x = x.reshape(b, s, c) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + return x + + +class SwiGLU(nn.Module): + + def __init__(self, dim, mid_dim): + super().__init__() + self.dim = dim + self.mid_dim = mid_dim + + # layers + self.fc1 = nn.Linear(dim, mid_dim) + self.fc2 = nn.Linear(dim, mid_dim) + self.fc3 = nn.Linear(mid_dim, dim) + + def forward(self, x): + x = F.silu(self.fc1(x)) * self.fc2(x) + x = self.fc3(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, + dim, + mlp_ratio, + num_heads, + post_norm=False, + causal=False, + activation='quick_gelu', + attn_dropout=0.0, + proj_dropout=0.0, + norm_eps=1e-5): + assert activation in ['quick_gelu', 'gelu', 'swi_glu'] + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.post_norm = post_norm + self.causal = causal + self.norm_eps = norm_eps + + # layers + self.norm1 = LayerNorm(dim, eps=norm_eps) + self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, + proj_dropout) + self.norm2 = LayerNorm(dim, eps=norm_eps) + if activation == 'swi_glu': + self.mlp = SwiGLU(dim, int(dim * mlp_ratio)) + else: + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + QuickGELU() if activation == 'quick_gelu' else nn.GELU(), + nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout)) + + def forward(self, x): + if self.post_norm: + x = x + self.norm1(self.attn(x)) + x = x + self.norm2(self.mlp(x)) + else: + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class AttentionPool(nn.Module): + + def __init__(self, + dim, + mlp_ratio, + num_heads, + activation='gelu', + proj_dropout=0.0, + norm_eps=1e-5): + assert dim % 
num_heads == 0 + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.proj_dropout = proj_dropout + self.norm_eps = norm_eps + + # layers + gain = 1.0 / math.sqrt(dim) + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.to_q = nn.Linear(dim, dim) + self.to_kv = nn.Linear(dim, dim * 2) + self.proj = nn.Linear(dim, dim) + self.norm = LayerNorm(dim, eps=norm_eps) + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + QuickGELU() if activation == 'quick_gelu' else nn.GELU(), + nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout)) + + def forward(self, x): + """ + x: [B, L, C]. + """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1) + k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2) + + # compute attention + x = flash_attention(q, k, v, version=2) + x = x.reshape(b, 1, c) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + + # mlp + x = x + self.mlp(self.norm(x)) + return x[:, 0] + + +class VisionTransformer(nn.Module): + + def __init__(self, + image_size=224, + patch_size=16, + dim=768, + mlp_ratio=4, + out_dim=512, + num_heads=12, + num_layers=12, + pool_type='token', + pre_norm=True, + post_norm=False, + activation='quick_gelu', + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5): + if image_size % patch_size != 0: + print( + '[WARNING] image_size is not divisible by patch_size', + flush=True) + assert pool_type in ('token', 'token_fc', 'attn_pool') + out_dim = out_dim or dim + super().__init__() + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = (image_size // patch_size)**2 + self.dim = dim + self.mlp_ratio = mlp_ratio + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.pool_type = pool_type + self.post_norm = post_norm + self.norm_eps = norm_eps + + # embeddings + gain = 1.0 / math.sqrt(dim) + self.patch_embedding = nn.Conv2d( + 3, + dim, + kernel_size=patch_size, + stride=patch_size, + bias=not pre_norm) + if pool_type in ('token', 'token_fc'): + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.pos_embedding = nn.Parameter(gain * torch.randn( + 1, self.num_patches + + (1 if pool_type in ('token', 'token_fc') else 0), dim)) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None + self.transformer = nn.Sequential(*[ + AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, + activation, attn_dropout, proj_dropout, norm_eps) + for _ in range(num_layers) + ]) + self.post_norm = LayerNorm(dim, eps=norm_eps) + + # head + if pool_type == 'token': + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + elif pool_type == 'token_fc': + self.head = nn.Linear(dim, out_dim) + elif pool_type == 'attn_pool': + self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, + proj_dropout, norm_eps) + + def forward(self, x, interpolation=False, use_31_block=False): + b = x.size(0) + + # embeddings + x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) + if self.pool_type in ('token', 'token_fc'): + x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1) + if interpolation: + e = pos_interpolate(self.pos_embedding, x.size(1)) + else: + e = self.pos_embedding + x = self.dropout(x + e) + if self.pre_norm 
is not None: + x = self.pre_norm(x) + + # transformer + if use_31_block: + x = self.transformer[:-1](x) + return x + else: + x = self.transformer(x) + return x + + +class XLMRobertaWithHead(XLMRoberta): + + def __init__(self, **kwargs): + self.out_dim = kwargs.pop('out_dim') + super().__init__(**kwargs) + + # head + mid_dim = (self.dim + self.out_dim) // 2 + self.head = nn.Sequential( + nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), + nn.Linear(mid_dim, self.out_dim, bias=False)) + + def forward(self, ids): + # xlm-roberta + x = super().forward(ids) + + # average pooling + mask = ids.ne(self.pad_id).unsqueeze(-1).to(x) + x = (x * mask).sum(dim=1) / mask.sum(dim=1) + + # head + x = self.head(x) + return x + + +class XLMRobertaCLIP(nn.Module): + + def __init__(self, + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool='token', + vision_pre_norm=True, + vision_post_norm=False, + activation='gelu', + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5): + super().__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_mlp_ratio = vision_mlp_ratio + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vision_pre_norm = vision_pre_norm + self.vision_post_norm = vision_post_norm + self.activation = activation + self.vocab_size = vocab_size + self.max_text_len = max_text_len + self.type_size = type_size + self.pad_id = pad_id + self.text_dim = text_dim + self.text_heads = text_heads + self.text_layers = text_layers + self.text_post_norm = text_post_norm + self.norm_eps = norm_eps + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + mlp_ratio=vision_mlp_ratio, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + pool_type=vision_pool, + pre_norm=vision_pre_norm, + post_norm=vision_post_norm, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + norm_eps=norm_eps) + self.textual = XLMRobertaWithHead( + vocab_size=vocab_size, + max_seq_len=max_text_len, + type_size=type_size, + pad_id=pad_id, + dim=text_dim, + out_dim=embed_dim, + num_heads=text_heads, + num_layers=text_layers, + post_norm=text_post_norm, + dropout=text_dropout) + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + + def forward(self, imgs, txt_ids): + """ + imgs: [B, 3, H, W] of torch.float32. + - mean: [0.48145466, 0.4578275, 0.40821073] + - std: [0.26862954, 0.26130258, 0.27577711] + txt_ids: [B, L] of torch.long. + Encoded by data.CLIPTokenizer. 
+ """ + xi = self.visual(imgs) + xt = self.textual(txt_ids) + return xi, xt + + def param_groups(self): + groups = [{ + 'params': [ + p for n, p in self.named_parameters() + if 'norm' in n or n.endswith('bias') + ], + 'weight_decay': 0.0 + }, { + 'params': [ + p for n, p in self.named_parameters() + if not ('norm' in n or n.endswith('bias')) + ] + }] + return groups + + +def _clip(pretrained=False, + pretrained_name=None, + model_cls=XLMRobertaCLIP, + return_transforms=False, + return_tokenizer=False, + tokenizer_padding='eos', + dtype=torch.float32, + device='cpu', + **kwargs): + # init a model on device + with torch.device(device): + model = model_cls(**kwargs) + + # set device + model = model.to(dtype=dtype, device=device) + output = (model,) + + # init transforms + if return_transforms: + # mean and std + if 'siglip' in pretrained_name.lower(): + mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] + else: + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + + # transforms + transforms = T.Compose([ + T.Resize((model.image_size, model.image_size), + interpolation=T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=mean, std=std) + ]) + output += (transforms,) + return output[0] if len(output) == 1 else output + + +def clip_xlm_roberta_vit_h_14( + pretrained=False, + pretrained_name='open-clip-xlm-roberta-large-vit-huge-14', + **kwargs): + cfg = dict( + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool='token', + activation='gelu', + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0) + cfg.update(**kwargs) + return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg) + + +class CLIPModel: + + def __init__(self, dtype, device, checkpoint_path, tokenizer_path): + self.dtype = dtype + self.device = device + self.checkpoint_path = checkpoint_path + self.tokenizer_path = tokenizer_path + + # init model + self.model, self.transforms = clip_xlm_roberta_vit_h_14( + pretrained=False, + return_transforms=True, + return_tokenizer=False, + dtype=dtype, + device=device) + self.model = self.model.eval().requires_grad_(False) + logging.info(f'loading {checkpoint_path}') + self.model.load_state_dict( + torch.load(checkpoint_path, map_location='cpu')) + + # init tokenizer + self.tokenizer = HuggingfaceTokenizer( + name=tokenizer_path, + seq_len=self.model.max_text_len - 2, + clean='whitespace') + + def visual(self, videos): + # preprocess + size = (self.model.image_size,) * 2 + videos = torch.cat([ + F.interpolate( + u.transpose(0, 1), + size=size, + mode='bicubic', + align_corners=False) for u in videos + ]) + videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) + + # forward + with torch.amp.autocast('cuda', dtype=self.dtype): + out = self.model.visual(videos, use_31_block=True) + return out diff --git a/humo/models/wan_modules/model.py b/humo/models/wan_modules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..1a9c0f3e5869d4e1d1f901f11e58e4678a850dab --- /dev/null +++ b/humo/models/wan_modules/model.py @@ -0,0 +1,619 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+import math + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin + +from .attention import flash_attention + +__all__ = ['WanModel'] + + +def sinusoidal_embedding_1d(dim, position): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float64) + + # calculation + sinusoid = torch.outer( + position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +@torch.amp.autocast("cuda", enabled=False) +def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer( + torch.arange(max_seq_len), + 1.0 / torch.pow(theta, + torch.arange(0, dim, 2).to(torch.float64).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + +@torch.amp.autocast("cuda", enabled=False) +def rope_apply(x, grid_sizes, freqs): + n, c = x.size(2), x.size(3) // 2 + + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape( + seq_len, n, -1, 2)) + freqs_i = torch.cat([ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], + dim=-1).reshape(seq_len, 1, -1) + + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, x[i, seq_len:]]) + + # append to collection + output.append(x_i) + return torch.stack(output).float() + + +class WanRMSNorm(nn.Module): + + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return self._norm(x.float()).type_as(x) * self.weight + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class WanLayerNorm(nn.LayerNorm): + + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return super().forward(x.float()).type_as(x) + + +class WanSelfAttention(nn.Module): + + def __init__(self, + dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, seq_lens, grid_sizes, freqs): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + seq_lens(Tensor): Shape [B] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = 
self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + x = flash_attention( + q=rope_apply(q, grid_sizes, freqs), + k=rope_apply(k, grid_sizes, freqs), + v=v, + k_lens=seq_lens, + window_size=self.window_size) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanT2VCrossAttention(WanSelfAttention): + + def forward(self, x, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + + # compute attention + x = flash_attention(q, k, v, k_lens=context_lens) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanI2VCrossAttention(WanSelfAttention): + + def __init__(self, + dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6): + super().__init__(dim, num_heads, window_size, qk_norm, eps) + + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + # self.alpha = nn.Parameter(torch.zeros((1, ))) + self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + context_img = context[:, :257] + context = context[:, 257:] + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d) + v_img = self.v_img(context_img).view(b, -1, n, d) + img_x = flash_attention(q, k_img, v_img, k_lens=None) + # compute attention + x = flash_attention(q, k, v, k_lens=context_lens) + + # output + x = x.flatten(2) + img_x = img_x.flatten(2) + x = x + img_x + x = self.o(x) + return x + + +WAN_CROSSATTENTION_CLASSES = { + 't2v_cross_attn': WanT2VCrossAttention, + 'i2v_cross_attn': WanI2VCrossAttention, +} + + +class WanAttentionBlock(nn.Module): + + def __init__(self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = WanLayerNorm(dim, eps) + self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, + eps) + self.norm3 = WanLayerNorm( + dim, eps, + elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, + num_heads, + (-1, -1), + qk_norm, + eps) + self.norm2 = WanLayerNorm(dim, eps) + self.ffn = nn.Sequential( + nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), + nn.Linear(ffn_dim, dim)) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + x, + e, + seq_lens, + grid_sizes, + freqs, + context, + context_lens, + ): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, 6, C] + seq_lens(Tensor): Shape [B], length of each sequence in batch + grid_sizes(Tensor): Shape [B, 
3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + assert e.dtype == torch.float32 + with torch.amp.autocast('cuda', dtype=torch.float32): + e = (self.modulation + e).chunk(6, dim=1) + assert e[0].dtype == torch.float32 + + # self-attention + y = self.self_attn( + self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, + freqs) + with torch.amp.autocast('cuda', dtype=torch.float32): + x = x + y * e[2] + + # cross-attention & ffn function + def cross_attn_ffn(x, context, context_lens, e): + x = x + self.cross_attn(self.norm3(x), context, context_lens) + y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) + with torch.amp.autocast('cuda', dtype=torch.float32): + x = x + y * e[5] + return x + + x = cross_attn_ffn(x, context, context_lens, e) + return x + + +class Head(nn.Module): + + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, C] + """ + assert e.dtype == torch.float32 + with torch.amp.autocast('cuda', dtype=torch.float32): + e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) + return x + + +class MLPProj(torch.nn.Module): + + def __init__(self, in_dim, out_dim): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim), + torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), + torch.nn.LayerNorm(out_dim)) + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + +class WanModel(ModelMixin, ConfigMixin): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + ignore_for_config = [ + 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size' + ] + _no_split_modules = ['WanAttentionBlock'] + + @register_to_config + def __init__(self, + model_type='t2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=5120, + ffn_dim=13824, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=40, + num_layers=40, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6): + r""" + Initialize the diffusion model backbone. 
+ + Args: + model_type (`str`, *optional*, defaults to 't2v'): + Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) + patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) + text_len (`int`, *optional*, defaults to 512): + Fixed length for text embeddings + in_dim (`int`, *optional*, defaults to 16): + Input video channels (C_in) + dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the transformer + ffn_dim (`int`, *optional*, defaults to 8192): + Intermediate dimension in feed-forward network + freq_dim (`int`, *optional*, defaults to 256): + Dimension for sinusoidal time embeddings + text_dim (`int`, *optional*, defaults to 4096): + Input dimension for text embeddings + out_dim (`int`, *optional*, defaults to 16): + Output video channels (C_out) + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads + num_layers (`int`, *optional*, defaults to 32): + Number of transformer blocks + window_size (`tuple`, *optional*, defaults to (-1, -1)): + Window size for local attention (-1 indicates global attention) + qk_norm (`bool`, *optional*, defaults to True): + Enable query/key normalization + cross_attn_norm (`bool`, *optional*, defaults to False): + Enable cross-attention normalization + eps (`float`, *optional*, defaults to 1e-6): + Epsilon value for normalization layers + """ + + super().__init__() + + assert model_type in ['t2v', 'i2v'] + self.model_type = model_type + + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # embeddings + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), + nn.Linear(dim, dim)) + + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + # blocks + cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' + self.blocks = nn.ModuleList([ + WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, + window_size, qk_norm, cross_attn_norm, eps) + for _ in range(num_layers) + ]) + + # head + self.head = Head(dim, out_dim, patch_size, eps) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.freqs = torch.cat([ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)) + ], + dim=1) + + if model_type == 'i2v': + self.img_emb = MLPProj(1280, dim) + + # initialize weights + self.init_weights() + + def forward( + self, + x, + t, + context, + seq_len, + clip_fea=None, + y=None, + ): + r""" + Forward pass through the diffusion model + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + clip_fea (Tensor, *optional*): + CLIP image features for 
image-to-video mode + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + if self.model_type == 'i2v': + assert clip_fea is not None and y is not None + # params + device = self.patch_embedding.weight.device + freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in x + ]) + + # time embeddings + with torch.amp.autocast('cuda', dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + if clip_fea is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=freqs, + context=context, + context_lens=context_lens) + + for block in self.blocks: + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] + + def unpatchify(self, x, grid_sizes): + r""" + Reconstruct video tensors from patch embeddings. + + Args: + x (List[Tensor]): + List of patchified features, each with shape [L, C_out * prod(patch_size)] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + List[Tensor]: + Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] + """ + + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[:math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self): + r""" + Initialize model parameters using Xavier initialization. 
+ """ + + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=.02) + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=.02) + + # init output layer + nn.init.zeros_(self.head.head.weight) diff --git a/humo/models/wan_modules/model_humo.py b/humo/models/wan_modules/model_humo.py new file mode 100644 index 0000000000000000000000000000000000000000..d7190fe08b49f8c05e7f29dcc956fba4bd7ecaa5 --- /dev/null +++ b/humo/models/wan_modules/model_humo.py @@ -0,0 +1,803 @@ +import torch +from torch import nn + +from common.distributed import get_device +from models.audio.audio_proj import AudioProjModel + +import torch.cuda.amp as amp +import math +from humo.models.wan_modules.attention import flash_attention +from common.distributed.advanced import is_unified_parallel_initialized + +import types + +def sinusoidal_embedding_1d(dim, position): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float64) + + # calculation + sinusoid = torch.outer( + position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +@amp.autocast(enabled=False) +def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer( + torch.arange(max_seq_len), + 1.0 / torch.pow(theta, + torch.arange(0, dim, 2).to(torch.float32).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + +@amp.autocast(enabled=False) +def rope_apply(x, grid_sizes, freqs): + n, c = x.size(2), x.size(3) // 2 + + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape( + seq_len, n, -1, 2)) + freqs_i = torch.cat([ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], + dim=-1).reshape(seq_len, 1, -1) + + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, x[i, seq_len:]]) + + # append to collection + output.append(x_i) + return torch.stack(output).float() + + +class WanRMSNorm(nn.Module): + + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return self._norm(x.float()).type_as(x) * self.weight + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class WanLayerNorm(nn.LayerNorm): + + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return super().forward(x.float()).type_as(x) + + +class WanSelfAttention(nn.Module): + + def __init__(self, + dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = 
dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, seq_lens, grid_sizes, freqs): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads], torch.Size([1, 9360, 5120]) + seq_lens(Tensor): Shape [B], tensor([9360]) + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W), tensor([[ 6, 30, 52]]) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + x = flash_attention( + q=rope_apply(q, grid_sizes, freqs), + k=rope_apply(k, grid_sizes, freqs), + v=v, + k_lens=seq_lens, + window_size=self.window_size) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanSelfAttentionSepKVDim(nn.Module): + + def __init__(self, + kv_dim, + dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(kv_dim, dim) + self.v = nn.Linear(kv_dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, seq_lens, grid_sizes, freqs): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads], torch.Size([1, 9360, 5120]) + seq_lens(Tensor): Shape [B], tensor([9360]) + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W), tensor([[ 6, 30, 52]]) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + x = flash_attention( + q=rope_apply(q, grid_sizes, freqs), + k=rope_apply(k, grid_sizes, freqs), + v=v, + k_lens=seq_lens, + window_size=self.window_size) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + + +class WanT2VCrossAttention(WanSelfAttention): + + def forward(self, x, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + + # compute attention + x = flash_attention(q, k, v, k_lens=context_lens) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanT2VCrossAttentionGather(WanSelfAttentionSepKVDim): + + def forward(self, x, context, context_lens, grid_sizes, freqs, audio_seq_len): + b, n, d = x.size(0), self.num_heads, 
self.head_dim + + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + + # --- NEW: derive sizes from shapes (SymInts), no int(tensor) casts --- + Lq = q.shape[1] # total video tokens per sample + # audio has 16 tokens per frame -> frames = audio_tokens // 16 + frames = (context.shape[1] // 16) + hlen_wlen = Lq // frames # tokens per frame = H*W + + # Now reshape using SymInt-derived sizes + q = q.reshape(-1, hlen_wlen, n, d) + k = k.reshape(-1, 16, n, d) + v = v.reshape(-1, 16, n, d) + + x = flash_attention(q, k, v, k_lens=None) + x = x.view(b, -1, n, d).flatten(2) + x = self.o(x) + return x + + # def forward(self, x, context, context_lens, grid_sizes, freqs, audio_seq_len): + # r""" + # Args: + # x(Tensor): Shape [B, L1, C] - video tokens + # context(Tensor): Shape [B, L2, C] - audio tokens with shape [B, frames*16, 1536] + # context_lens(Tensor): Shape [B] - actually seq_lens from call (video sequence length) + # grid_sizes(Tensor): Shape [B, 3] - video grid dimensions (F, H, W) + # freqs(Tensor): RoPE frequencies + # audio_seq_len(Tensor): Actual audio sequence length (frames * 16) + # """ + # b, n, d = x.size(0), self.num_heads, self.head_dim + + # q = self.norm_q(self.q(x)).view(b, -1, n, d) + # k = self.norm_k(self.k(context)).view(b, -1, n, d) + # v = self.v(context).view(b, -1, n, d) + + # # Handle video spatial structure + # hlen_wlen = int(grid_sizes[0][1] * grid_sizes[0][2]) + # q = q.reshape(-1, hlen_wlen, n, d) + + # # Handle audio temporal structure (16 tokens per frame) + # k = k.reshape(-1, 16, n, d) + # v = v.reshape(-1, 16, n, d) + + # # Cross-attention + # x = flash_attention(q, k, v, k_lens=None) # No masking for audio + + # x = x.view(b, -1, n, d).flatten(2) + # x = self.o(x) + # return x + + +class AudioCrossAttentionWrapper(nn.Module): + def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6,): + super().__init__() + + self.audio_cross_attn = WanT2VCrossAttentionGather( + kv_dim, dim, num_heads, (-1, -1), qk_norm, eps) + self.norm1_audio = WanLayerNorm(dim, eps, + elementwise_affine=True) + + def forward(self, x, audio, seq_lens, grid_sizes, freqs, audio_seq_len): + x = x + self.audio_cross_attn( + self.norm1_audio(x), audio, seq_lens, grid_sizes, freqs, audio_seq_len) + return x + + +class WanI2VCrossAttention(WanSelfAttention): + + def __init__(self, + dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6): + super().__init__(dim, num_heads, window_size, qk_norm, eps) + + def forward(self, x, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + x = flash_attention(q, k, v, k_lens=context_lens) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +WAN_CROSSATTENTION_CLASSES = { + 't2v_cross_attn': WanT2VCrossAttention, + 'i2v_cross_attn': WanI2VCrossAttention, +} + +class WanAttentionBlock(nn.Module): + + def __init__(self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + use_audio=True): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.window_size = window_size + self.qk_norm = qk_norm + 
self.cross_attn_norm = cross_attn_norm
+        self.eps = eps
+
+        # layers
+        self.norm1 = WanLayerNorm(dim, eps)
+        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
+                                          eps)
+        self.norm3 = WanLayerNorm(
+            dim, eps,
+            elementwise_affine=True) if cross_attn_norm else nn.Identity()
+        self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
+                                                                      num_heads,
+                                                                      (-1, -1),
+                                                                      qk_norm,
+                                                                      eps)
+        self.norm2 = WanLayerNorm(dim, eps)
+        self.ffn = nn.Sequential(
+            nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
+            nn.Linear(ffn_dim, dim))
+
+        # modulation
+        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+        self.use_audio = use_audio
+        if use_audio:
+            self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(dim, 1536, num_heads, qk_norm, eps)
+
+    def forward(
+        self,
+        x,              # torch.Size([1, 9360, 5120])
+        e,              # torch.Size([1, 6, 5120])
+        seq_lens,       # tensor([9360])
+        grid_sizes,     # tensor([[ 6, 30, 52]])
+        freqs,          # torch.Size([1024, 64])
+        context,        # torch.Size([1, 512, 5120])
+        context_lens,   # None
+        audio=None,     # None
+        audio_seq_len=None,
+        ref_num_list=None,
+    ):
+        r"""
+        Args:
+            x(Tensor): Shape [B, L, C]
+            e(Tensor): Shape [B, 6, C]
+            audio(Tensor): Shape [B, L, C]
+            seq_lens(Tensor): Shape [B], length of each sequence in batch
+            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
+            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+            ref_num_list: used together with seq_lens to locate how many positions
+                from the end of the sequence the reference image tokens sit
+        """
+        assert e.dtype == torch.float32
+        with torch.amp.autocast('cuda', dtype=torch.float32):
+            e = (self.modulation + e).chunk(6, dim=1)
+        assert e[0].dtype == torch.float32
+
+        # self-attention
+        y = self.self_attn(
+            self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
+            freqs)
+        with torch.amp.autocast('cuda', dtype=torch.float32):
+            x = x + y * e[2]
+
+        # cross-attention & ffn function
+        def cross_attn_ffn(x, context, context_lens, e):
+            x = x + self.cross_attn(self.norm3(x), context, context_lens)
+
+            if self.use_audio:
+                x = self.audio_cross_attn_wrapper(x, audio, seq_lens, grid_sizes, freqs, audio_seq_len)
+
+            y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
+            with torch.amp.autocast('cuda', dtype=torch.float32):
+                x = x + y * e[5]
+            return x
+
+        x = cross_attn_ffn(x, context, context_lens, e)
+
+        return x
+
+
+class Head(nn.Module):
+
+    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
+        super().__init__()
+        self.dim = dim
+        self.out_dim = out_dim
+        self.patch_size = patch_size
+        self.eps = eps
+
+        # layers
+        out_dim = math.prod(patch_size) * out_dim
+        self.norm = WanLayerNorm(dim, eps)
+        self.head = nn.Linear(dim, out_dim)
+
+        # modulation
+        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
+
+    def forward(self, x, e):
+        r"""
+        Args:
+            x(Tensor): Shape [B, L1, C]
+            e(Tensor): Shape [B, C]
+        """
+        assert e.dtype == torch.float32
+        with torch.amp.autocast('cuda', dtype=torch.float32):
+            e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
+            x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
+        return x
+
+
+class MLPProj(torch.nn.Module):
+
+    def __init__(self, in_dim, out_dim):
+        super().__init__()
+
+        self.proj = torch.nn.Sequential(
+            torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
+            torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
+            torch.nn.LayerNorm(out_dim))
+
+    def forward(self, image_embeds):
+        clip_extra_context_tokens = self.proj(image_embeds)
+        return clip_extra_context_tokens
+
+
+class WanModel(nn.Module):
+    r"""
+    Wan diffusion backbone supporting both text-to-video and
image-to-video.
+    """
+
+    ignore_for_config = [
+        'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
+    ]
+    _no_split_modules = ['WanAttentionBlock']
+
+    gradient_checkpointing = False
+
+    def __init__(self,
+                 model_type='t2v',
+                 patch_size=(1, 2, 2),
+                 text_len=512,
+                 in_dim=16,
+                 dim=2048,
+                 ffn_dim=13824,
+                 freq_dim=256,
+                 text_dim=4096,
+                 out_dim=16,
+                 num_heads=40,
+                 num_layers=40,
+                 window_size=(-1, -1),
+                 qk_norm=True,
+                 cross_attn_norm=True,
+                 eps=1e-6,
+                 audio_token_num=16,
+                 insert_audio=True):
+        r"""
+        Initialize the diffusion model backbone.
+
+        Args:
+            model_type (`str`, *optional*, defaults to 't2v'):
+                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
+            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
+                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
+            text_len (`int`, *optional*, defaults to 512):
+                Fixed length for text embeddings
+            in_dim (`int`, *optional*, defaults to 16):
+                Input video channels (C_in)
+            dim (`int`, *optional*, defaults to 2048):
+                Hidden dimension of the transformer
+            ffn_dim (`int`, *optional*, defaults to 13824):
+                Intermediate dimension in feed-forward network
+            freq_dim (`int`, *optional*, defaults to 256):
+                Dimension for sinusoidal time embeddings
+            text_dim (`int`, *optional*, defaults to 4096):
+                Input dimension for text embeddings
+            out_dim (`int`, *optional*, defaults to 16):
+                Output video channels (C_out)
+            num_heads (`int`, *optional*, defaults to 40):
+                Number of attention heads
+            num_layers (`int`, *optional*, defaults to 40):
+                Number of transformer blocks
+            window_size (`tuple`, *optional*, defaults to (-1, -1)):
+                Window size for local attention (-1 indicates global attention)
+            qk_norm (`bool`, *optional*, defaults to True):
+                Enable query/key normalization
+            cross_attn_norm (`bool`, *optional*, defaults to True):
+                Enable cross-attention normalization
+            eps (`float`, *optional*, defaults to 1e-6):
+                Epsilon value for normalization layers
+            audio_token_num (`int`, *optional*, defaults to 16):
+                Number of audio context tokens produced per latent frame by the audio projector
+            insert_audio (`bool`, *optional*, defaults to True):
+                Add an audio cross-attention branch to every transformer block
+        """
+
+        super().__init__()
+
+        assert model_type in ['t2v', 'i2v']
+        self.model_type = model_type
+
+        self.patch_size = patch_size
+        self.text_len = text_len
+        self.in_dim = in_dim
+        self.dim = dim
+        self.ffn_dim = ffn_dim
+        self.freq_dim = freq_dim
+        self.text_dim = text_dim
+        self.out_dim = out_dim
+        self.num_heads = num_heads
+        self.num_layers = num_layers
+        self.window_size = window_size
+        self.qk_norm = qk_norm
+        self.cross_attn_norm = cross_attn_norm
+        self.eps = eps
+
+        # embeddings
+        self.patch_embedding = nn.Conv3d(
+            in_dim, dim, kernel_size=patch_size, stride=patch_size)
+        self.text_embedding = nn.Sequential(
+            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
+            nn.Linear(dim, dim))
+
+        self.time_embedding = nn.Sequential(
+            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
+        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
+
+        # blocks
+        cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
+        self.insert_audio = insert_audio
+        self.blocks = nn.ModuleList([
+            WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
+                              window_size, qk_norm, cross_attn_norm,
+                              eps, use_audio=self.insert_audio)
+            for _ in range(num_layers)
+        ])
+
+        # head
+        self.head = Head(dim, out_dim, patch_size, eps)
+
+        if self.insert_audio:
+            self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280,
+                intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num)
+
+        # RoPE freqs: register as a buffer so it moves with .to() / DDP and is tracked by compile
+        assert (dim %
num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + + _freqs = torch.cat([ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)) + ], dim=1) + self.register_buffer("freqs", _freqs, persistent=False) + + # initialize weights + self.init_weights() + + # initialize unified parallel + if is_unified_parallel_initialized(): + print(f"Initializing WanModel with unified parallel initialized") + from humo.models.distributed.dit_ulysses_sequence_parallel import ulysses_attn_forward, ulysses_dit_forward, ulysses_audio_cross_attn_forward + for block in self.blocks: + block.self_attn.forward = types.MethodType(ulysses_attn_forward, block.self_attn) + if block.use_audio: + block.audio_cross_attn_wrapper.audio_cross_attn.forward = types.MethodType(ulysses_audio_cross_attn_forward, block.audio_cross_attn_wrapper.audio_cross_attn) + self.forward = types.MethodType(ulysses_dit_forward, self) + + def forward( + self, + x, + t, + context, + seq_len, + audio=None, + y=None, + tea_cache=None, + ): + r""" + Forward pass through the diffusion model + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + clip_fea (Tensor, *optional*): + CLIP image features for image-to-video mode + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + if self.model_type == 'i2v': + assert y is not None + + # params + freqs = self.freqs + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + + # pad to uniform length and batch + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) + for u in x + ]) # shape: [B, seq_len, C] + + # time embeddings + with torch.amp.autocast('cuda', dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float() + ).float() + e0 = self.time_projection(e).unflatten(1, (6, self.dim)).float() + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ]) + ) + + # audio (unchanged; not cached) + if self.insert_audio: + audio = [self.audio_proj(au.unsqueeze(0)).permute(0, 3, 1, 2) for au in audio] + audio_seq_len = max(au.shape[2] for au in audio) * audio[0].shape[3] + + audio = [au.flatten(2).transpose(1, 2) for au in audio] # [1, t*32, 1536] + audio = torch.cat([ + torch.cat([au, au.new_zeros(1, int(audio_seq_len) - au.size(1), au.size(2))], dim=1) + for au in audio + ]) + else: + audio = None + audio_seq_len = None + + # ---- tea_cache integration (mirrors your working model) ---- + if tea_cache is not None: + # Use the pre-block tokens 'x' and time-mod 'e0' to decide whether to reuse cache + tea_cache_update = tea_cache.check(self, x, e0) + else: + 
tea_cache_update = False + + ori_x_len = x.shape[1] # remember original token length before potential cache extension + + if tea_cache_update: + # Let the cache inject/append any needed past states/tokens for reuse + x = tea_cache.update(x) + else: + # arguments for blocks + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=freqs, + context=context, + context_lens=context_lens, + audio=audio, + audio_seq_len=audio_seq_len + ) + + # transformer blocks + for block in self.blocks: + x = block(x, **kwargs) + + if tea_cache is not None: + x_cache = x[:, :ori_x_len] + tea_cache.store(x_cache) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] + + + def unpatchify(self, x, grid_sizes): + r""" + Reconstruct video tensors from patch embeddings. + + Args: + x (List[Tensor]): + List of patchified features, each with shape [L, C_out * prod(patch_size)] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + List[Tensor]: + Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] + """ + + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[:math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self): + r""" + Initialize model parameters using Xavier initialization. + """ + + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=.02) + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=.02) + + # init output layer + nn.init.zeros_(self.head.head.weight) diff --git a/humo/models/wan_modules/t5.py b/humo/models/wan_modules/t5.py new file mode 100644 index 0000000000000000000000000000000000000000..c6005c88e95208881064bc8b3bf6fe8df6c4126f --- /dev/null +++ b/humo/models/wan_modules/t5.py @@ -0,0 +1,525 @@ +# Modified from transformers.models.t5.modeling_t5 +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
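+#
+# Usage sketch (illustrative only -- the checkpoint and tokenizer paths below are
+# placeholders, not files shipped with this repo):
+#
+#     text_encoder = T5EncoderModel(
+#         text_len=512,
+#         dtype=torch.bfloat16,
+#         device='cuda',
+#         checkpoint_path='/path/to/umt5-xxl-encoder.pth',  # hypothetical path
+#         tokenizer_path='google/umt5-xxl')                 # hypothetical tokenizer id
+#     context = text_encoder(["a prompt"], device='cuda')   # list of [L_i, 4096] tensors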
+import logging +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .tokenizers import HuggingfaceTokenizer + +__all__ = [ + 'T5Model', + 'T5Encoder', + 'T5Decoder', + 'T5EncoderModel', +] + + +def fp16_clamp(x): + if x.dtype == torch.float16 and torch.isinf(x).any(): + clamp = torch.finfo(x.dtype).max - 1000 + x = torch.clamp(x, min=-clamp, max=clamp) + return x + + +def init_weights(m): + if isinstance(m, T5LayerNorm): + nn.init.ones_(m.weight) + elif isinstance(m, T5Model): + nn.init.normal_(m.token_embedding.weight, std=1.0) + elif isinstance(m, T5FeedForward): + nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) + nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) + nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) + elif isinstance(m, T5Attention): + nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5) + nn.init.normal_(m.k.weight, std=m.dim**-0.5) + nn.init.normal_(m.v.weight, std=m.dim**-0.5) + nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5) + elif isinstance(m, T5RelativeEmbedding): + nn.init.normal_( + m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5) + + +class GELU(nn.Module): + + def forward(self, x): + return 0.5 * x * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +class T5LayerNorm(nn.Module): + + def __init__(self, dim, eps=1e-6): + super(T5LayerNorm, self).__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.type_as(self.weight) + return self.weight * x + + +class T5Attention(nn.Module): + + def __init__(self, dim, dim_attn, num_heads, dropout=0.1): + assert dim_attn % num_heads == 0 + super(T5Attention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.num_heads = num_heads + self.head_dim = dim_attn // num_heads + + # layers + self.q = nn.Linear(dim, dim_attn, bias=False) + self.k = nn.Linear(dim, dim_attn, bias=False) + self.v = nn.Linear(dim, dim_attn, bias=False) + self.o = nn.Linear(dim_attn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None, mask=None, pos_bias=None): + """ + x: [B, L1, C]. + context: [B, L2, C] or None. + mask: [B, L2] or [B, L1, L2] or None. 
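+        pos_bias: [1, num_heads, L1, L2] or None; additive relative-position
+            bias that is summed into the attention logits before the softmax.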
+ """ + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # attention bias + attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) + if pos_bias is not None: + attn_bias += pos_bias + if mask is not None: + assert mask.ndim in [2, 3] + mask = mask.view(b, 1, 1, + -1) if mask.ndim == 2 else mask.unsqueeze(1) + attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) + + # compute attention (T5 does not use scaling) + attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + x = torch.einsum('bnij,bjnc->binc', attn, v) + + # output + x = x.reshape(b, -1, n * c) + x = self.o(x) + x = self.dropout(x) + return x + + +class T5FeedForward(nn.Module): + + def __init__(self, dim, dim_ffn, dropout=0.1): + super(T5FeedForward, self).__init__() + self.dim = dim + self.dim_ffn = dim_ffn + + # layers + self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) + self.fc1 = nn.Linear(dim, dim_ffn, bias=False) + self.fc2 = nn.Linear(dim_ffn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.fc1(x) * self.gate(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class T5SelfAttention(nn.Module): + + def __init__(self, + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5SelfAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) + + def forward(self, x, mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding( + x.size(1), x.size(1)) + x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.ffn(self.norm2(x))) + return x + + +class T5CrossAttention(nn.Module): + + def __init__(self, + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5CrossAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm3 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=False) + + def forward(self, + x, + mask=None, + encoder_states=None, + encoder_mask=None, + pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding( + x.size(1), x.size(1)) + x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.cross_attn( + self.norm2(x), context=encoder_states, mask=encoder_mask)) + x = fp16_clamp(x + self.ffn(self.norm3(x))) + return x + + +class 
T5RelativeEmbedding(nn.Module): + + def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): + super(T5RelativeEmbedding, self).__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + + # layers + self.embedding = nn.Embedding(num_buckets, num_heads) + + def forward(self, lq, lk): + device = self.embedding.weight.device + # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ + # torch.arange(lq).unsqueeze(1).to(device) + rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \ + torch.arange(lq, device=device).unsqueeze(1) + rel_pos = self._relative_position_bucket(rel_pos) + rel_pos_embeds = self.embedding(rel_pos) + rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze( + 0) # [1, N, Lq, Lk] + return rel_pos_embeds.contiguous() + + def _relative_position_bucket(self, rel_pos): + # preprocess + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).long() * num_buckets + rel_pos = torch.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = 0 + rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) + + # embeddings for small and large positions + max_exact = num_buckets // 2 + rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / + math.log(self.max_dist / max_exact) * + (num_buckets - max_exact)).long() + rel_pos_large = torch.min( + rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) + rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) + return rel_buckets + + +class T5Encoder(nn.Module): + + def __init__(self, + vocab, + dim, + dim_attn, + dim_ffn, + num_heads, + num_layers, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5Encoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ + else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([ + T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, + shared_pos, dropout) for _ in range(num_layers) + ]) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None): + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), + x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Decoder(nn.Module): + + def __init__(self, + vocab, + dim, + dim_attn, + dim_ffn, + num_heads, + num_layers, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5Decoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ + else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=False) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([ + T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, + 
shared_pos, dropout) for _ in range(num_layers) + ]) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None): + b, s = ids.size() + + # causal mask + if mask is None: + mask = torch.tril(torch.ones(1, s, s).to(ids.device)) + elif mask.ndim == 2: + mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1)) + + # layers + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), + x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Model(nn.Module): + + def __init__(self, + vocab_size, + dim, + dim_attn, + dim_ffn, + num_heads, + encoder_layers, + decoder_layers, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5Model, self).__init__() + self.vocab_size = vocab_size + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + self.num_buckets = num_buckets + + # layers + self.token_embedding = nn.Embedding(vocab_size, dim) + self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn, + num_heads, encoder_layers, num_buckets, + shared_pos, dropout) + self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn, + num_heads, decoder_layers, num_buckets, + shared_pos, dropout) + self.head = nn.Linear(dim, vocab_size, bias=False) + + # initialize weights + self.apply(init_weights) + + def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): + x = self.encoder(encoder_ids, encoder_mask) + x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) + x = self.head(x) + return x + + +def _t5(name, + encoder_only=False, + decoder_only=False, + return_tokenizer=False, + tokenizer_kwargs={}, + dtype=torch.float32, + device='cpu', + **kwargs): + # sanity check + assert not (encoder_only and decoder_only) + + # params + if encoder_only: + model_cls = T5Encoder + kwargs['vocab'] = kwargs.pop('vocab_size') + kwargs['num_layers'] = kwargs.pop('encoder_layers') + _ = kwargs.pop('decoder_layers') + elif decoder_only: + model_cls = T5Decoder + kwargs['vocab'] = kwargs.pop('vocab_size') + kwargs['num_layers'] = kwargs.pop('decoder_layers') + _ = kwargs.pop('encoder_layers') + else: + model_cls = T5Model + + # init model + with torch.device(device): + model = model_cls(**kwargs) + + # set device + model = model.to(dtype=dtype, device=device) + + # init tokenizer + if return_tokenizer: + from .tokenizers import HuggingfaceTokenizer + tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs) + return model, tokenizer + else: + return model + + +def umt5_xxl(**kwargs): + cfg = dict( + vocab_size=256384, + dim=4096, + dim_attn=4096, + dim_ffn=10240, + num_heads=64, + encoder_layers=24, + decoder_layers=24, + num_buckets=32, + shared_pos=False, + dropout=0.1) + cfg.update(**kwargs) + return _t5('umt5-xxl', **cfg) + + +class T5EncoderModel(nn.Module): + + def __init__( + self, + text_len, + dtype=torch.bfloat16, + device=torch.cuda.current_device(), + checkpoint_path=None, + tokenizer_path=None, + shard_fn=None, + ): + super(T5EncoderModel, self).__init__() + self.text_len = text_len + self.dtype = dtype + self.device = device + self.checkpoint_path = checkpoint_path + self.tokenizer_path = tokenizer_path + + with torch.device(device): + self.model = T5Encoder( + 
vocab=256384, + dim=4096, + dim_attn=4096, + dim_ffn=10240, + num_heads=64, + num_layers=24, + num_buckets=32, + shared_pos=False, + dropout=0.1 + ) + # set device + self.model = self.model.to(dtype=dtype, device=device).eval().requires_grad_(False) + + logging.info(f'loading {checkpoint_path}') + if checkpoint_path is not None: + self.model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')) + + if shard_fn is not None: + self.model = shard_fn(self.model, sync_module_states=False) + else: + self.model.to(self.device) + # init tokenizer + self.tokenizer = HuggingfaceTokenizer( + name=tokenizer_path, seq_len=text_len, clean='whitespace') + + @torch.no_grad() + def __call__(self, texts, device): + ids, mask = self.tokenizer( + texts, return_mask=True, add_special_tokens=True) + ids = ids.to(device) + mask = mask.to(device) + seq_lens = mask.gt(0).sum(dim=1).long() + context = self.model(ids, mask) + return [u[:v] for u, v in zip(context, seq_lens)] diff --git a/humo/models/wan_modules/tokenizers.py b/humo/models/wan_modules/tokenizers.py new file mode 100644 index 0000000000000000000000000000000000000000..329c2add418c49c5df1c589c45dd124a59caafc7 --- /dev/null +++ b/humo/models/wan_modules/tokenizers.py @@ -0,0 +1,82 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import html +import string + +import ftfy +import regex as re +from transformers import AutoTokenizer + +__all__ = ['HuggingfaceTokenizer'] + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +def canonicalize(text, keep_punctuation_exact_string=None): + text = text.replace('_', ' ') + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + part.translate(str.maketrans('', '', string.punctuation)) + for part in text.split(keep_punctuation_exact_string)) + else: + text = text.translate(str.maketrans('', '', string.punctuation)) + text = text.lower() + text = re.sub(r'\s+', ' ', text) + return text.strip() + + +class HuggingfaceTokenizer: + + def __init__(self, name, seq_len=None, clean=None, **kwargs): + assert clean in (None, 'whitespace', 'lower', 'canonicalize') + self.name = name + self.seq_len = seq_len + self.clean = clean + + # init tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) + self.vocab_size = self.tokenizer.vocab_size + + def __call__(self, sequence, **kwargs): + return_mask = kwargs.pop('return_mask', False) + + # arguments + _kwargs = {'return_tensors': 'pt'} + if self.seq_len is not None: + _kwargs.update({ + 'padding': 'max_length', + 'truncation': True, + 'max_length': self.seq_len + }) + _kwargs.update(**kwargs) + + # tokenization + if isinstance(sequence, str): + sequence = [sequence] + if self.clean: + sequence = [self._clean(u) for u in sequence] + ids = self.tokenizer(sequence, **_kwargs) + + # output + if return_mask: + return ids.input_ids, ids.attention_mask + else: + return ids.input_ids + + def _clean(self, text): + if self.clean == 'whitespace': + text = whitespace_clean(basic_clean(text)) + elif self.clean == 'lower': + text = whitespace_clean(basic_clean(text)).lower() + elif self.clean == 'canonicalize': + text = canonicalize(basic_clean(text)) + return text diff --git a/humo/models/wan_modules/vae.py b/humo/models/wan_modules/vae.py new file mode 100644 index 
0000000000000000000000000000000000000000..33b0ba82f5ce451fbecd8a0deaae19ffa8ca4b38 --- /dev/null +++ b/humo/models/wan_modules/vae.py @@ -0,0 +1,666 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +__all__ = [ + 'WanVAE', +] + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = (self.padding[2], self.padding[2], self.padding[1], + self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. + + def forward(self, x): + return F.normalize( + x, dim=(1 if self.channel_first else + -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. + """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', + 'downsample3d') + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == 'upsample2d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + elif mode == 'upsample3d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + self.time_conv = CausalConv3d( + dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == 'downsample2d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == 'downsample3d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == 'upsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = 'Rep' + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] != 'Rep': + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] == 'Rep': + cache_x = torch.cat([ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2) + if feat_cache[idx] == 'Rep': + x = self.time_conv(x) + else: + x = 
self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.resample(x) + x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + + if self.mode == 'downsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -1:, :, :].clone() + # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': + # # cache last frame of last two chunk + # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 + conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5 + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) + conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1)) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ + if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. 
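+
+    Note: the time axis is folded into the batch dimension before attention,
+    so each frame attends over its own H*W positions with a single head.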
+ """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, + -1).permute(0, 1, 3, + 2).contiguous().chunk( + 3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, '(b t) c h w-> b c t h w', t=t) + return x + identity + + +class Encoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = 'downsample3d' if temperal_downsample[ + i] else 'downsample2d' + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout)) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +class Decoder3d(nn.Module): + + def __init__(self, + 
dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 3, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class WanVAE_(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + 
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
+                                 attn_scales, self.temperal_upsample, dropout)
+
+    def forward(self, x):
+        mu, log_var = self.encode(x)
+        z = self.reparameterize(mu, log_var)
+        x_recon = self.decode(z)
+        return x_recon, mu, log_var
+
+    def encode(self, x, scale):
+        self.clear_cache()
+        ## cache
+        t = x.shape[2]
+        iter_ = 1 + (t - 1) // 4
+        ## split the encoder input x along the time axis into chunks of 1, 4, 4, 4, ... frames
+        for i in range(iter_):
+            self._enc_conv_idx = [0]
+            if i == 0:
+                out = self.encoder(
+                    x[:, :, :1, :, :],
+                    feat_cache=self._enc_feat_map,
+                    feat_idx=self._enc_conv_idx)
+            else:
+                out_ = self.encoder(
+                    x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
+                    feat_cache=self._enc_feat_map,
+                    feat_idx=self._enc_conv_idx)
+                out = torch.cat([out, out_], 2)
+        mu, log_var = self.conv1(out).chunk(2, dim=1)
+        if isinstance(scale[0], torch.Tensor):
+            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
+                1, self.z_dim, 1, 1, 1)
+        else:
+            mu = (mu - scale[0]) * scale[1]
+        self.clear_cache()
+        return mu
+
+    def decode(self, z, scale):
+        self.clear_cache()
+        # z: [b,c,t,h,w]
+        if isinstance(scale[0], torch.Tensor):
+            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
+                1, self.z_dim, 1, 1, 1)
+        else:
+            z = z / scale[1] + scale[0]
+        iter_ = z.shape[2]
+        x = self.conv2(z)
+        for i in range(iter_):
+            self._conv_idx = [0]
+            if i == 0:
+                out = self.decoder(
+                    x[:, :, i:i + 1, :, :],
+                    feat_cache=self._feat_map,
+                    feat_idx=self._conv_idx)
+            else:
+                out_ = self.decoder(
+                    x[:, :, i:i + 1, :, :],
+                    feat_cache=self._feat_map,
+                    feat_idx=self._conv_idx)
+                out = torch.cat([out, out_], 2)
+        self.clear_cache()
+        return out
+
+    def reparameterize(self, mu, log_var):
+        std = torch.exp(0.5 * log_var)
+        eps = torch.randn_like(std)
+        return eps * std + mu
+
+    def sample(self, imgs, deterministic=False):
+        mu, log_var = self.encode(imgs)
+        if deterministic:
+            return mu
+        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
+        return mu + std * torch.randn_like(std)
+
+    def clear_cache(self):
+        self._conv_num = count_conv3d(self.decoder)
+        self._conv_idx = [0]
+        self._feat_map = [None] * self._conv_num
+        # cache encode
+        self._enc_conv_num = count_conv3d(self.encoder)
+        self._enc_conv_idx = [0]
+        self._enc_feat_map = [None] * self._enc_conv_num
+
+
+def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
+    """
+    Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
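+
+    Builds a `WanVAE_` with the default config below (96 base channels, causal
+    3D convolutions, temporal downsampling in the last two stages) and, when
+    `pretrained_path` is given, loads its state dict with `map_location=device`.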
+ """ + # params + cfg = dict( + dim=96, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0) + cfg.update(**kwargs) + + # init model + # with torch.device('meta'): + model = WanVAE_(**cfg) + + # load checkpoint + logging.info(f'loading {pretrained_path}') + if pretrained_path is not None: + model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True) + + return model + + +class WanVAE: + + def __init__(self, + z_dim=16, + vae_pth=None, + dtype=torch.float, + device="cuda"): + self.dtype = dtype + self.device = device + + mean = [ + -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 + ] + std = [ + 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 + ] + self.mean = torch.tensor(mean, dtype=dtype, device=device) + self.std = torch.tensor(std, dtype=dtype, device=device) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = _video_vae( + pretrained_path=vae_pth, + z_dim=z_dim, + ).eval().requires_grad_(False).to(device) + + @torch.no_grad() + def encode(self, videos, device): + """ + videos: A list of videos each with shape [C, T, H, W]. + """ + + with torch.amp.autocast('cuda', dtype=self.dtype): + return [ + self.model.encode(u.unsqueeze(0).to(device,self.dtype), self.scale).float().squeeze(0) + for u in videos + ] + + @torch.no_grad() + def decode(self, zs): + with torch.amp.autocast('cuda', dtype=self.dtype): + return [ + self.model.decode(u.unsqueeze(0), + self.scale).float().clamp_(-1, 1).squeeze(0) + for u in zs + ] diff --git a/humo/models/wan_modules/xlm_roberta.py b/humo/models/wan_modules/xlm_roberta.py new file mode 100644 index 0000000000000000000000000000000000000000..34858de961e1033ad120c13b1a0342f84f4907f6 --- /dev/null +++ b/humo/models/wan_modules/xlm_roberta.py @@ -0,0 +1,170 @@ +# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['XLMRoberta', 'xlm_roberta_large'] + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, mask): + """ + x: [B, L, C]. 
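+        mask: [B, 1, 1, L] additive attention mask (0.0 for visible tokens, a
+            large negative value for padding), as constructed in XLMRoberta.forward.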
+ """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + + # compute attention + p = self.dropout.p if self.training else 0.0 + x = F.scaled_dot_product_attention(q, k, v, mask, p) + x = x.permute(0, 2, 1, 3).reshape(b, s, c) + + # output + x = self.o(x) + x = self.dropout(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.post_norm = post_norm + self.eps = eps + + # layers + self.attn = SelfAttention(dim, num_heads, dropout, eps) + self.norm1 = nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential( + nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), + nn.Dropout(dropout)) + self.norm2 = nn.LayerNorm(dim, eps=eps) + + def forward(self, x, mask): + if self.post_norm: + x = self.norm1(x + self.attn(x, mask)) + x = self.norm2(x + self.ffn(x)) + else: + x = x + self.attn(self.norm1(x), mask) + x = x + self.ffn(self.norm2(x)) + return x + + +class XLMRoberta(nn.Module): + """ + XLMRobertaModel with no pooler and no LM head. + """ + + def __init__(self, + vocab_size=250002, + max_seq_len=514, + type_size=1, + pad_id=1, + dim=1024, + num_heads=16, + num_layers=24, + post_norm=True, + dropout=0.1, + eps=1e-5): + super().__init__() + self.vocab_size = vocab_size + self.max_seq_len = max_seq_len + self.type_size = type_size + self.pad_id = pad_id + self.dim = dim + self.num_heads = num_heads + self.num_layers = num_layers + self.post_norm = post_norm + self.eps = eps + + # embeddings + self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id) + self.type_embedding = nn.Embedding(type_size, dim) + self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id) + self.dropout = nn.Dropout(dropout) + + # blocks + self.blocks = nn.ModuleList([ + AttentionBlock(dim, num_heads, post_norm, dropout, eps) + for _ in range(num_layers) + ]) + + # norm layer + self.norm = nn.LayerNorm(dim, eps=eps) + + def forward(self, ids): + """ + ids: [B, L] of torch.LongTensor. + """ + b, s = ids.shape + mask = ids.ne(self.pad_id).long() + + # embeddings + x = self.token_embedding(ids) + \ + self.type_embedding(torch.zeros_like(ids)) + \ + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask) + if self.post_norm: + x = self.norm(x) + x = self.dropout(x) + + # blocks + mask = torch.where( + mask.view(b, 1, 1, s).gt(0), 0.0, + torch.finfo(x.dtype).min) + for block in self.blocks: + x = block(x, mask) + + # output + if not self.post_norm: + x = self.norm(x) + return x + + +def xlm_roberta_large(pretrained=False, + return_tokenizer=False, + device='cpu', + **kwargs): + """ + XLMRobertaLarge adapted from Huggingface. 
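+
+    Usage sketch (token ids are illustrative and should come from an XLM-R
+    tokenizer with pad_id=1; the ``pretrained``/``return_tokenizer`` flags are
+    accepted but not used in this function body):
+
+        model = xlm_roberta_large(device='cpu')
+        ids = torch.tensor([[0, 250001, 2, 1]])   # <s> <mask> </s> <pad>
+        feats = model(ids)                        # [1, 4, 1024]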
+ """ + # params + cfg = dict( + vocab_size=250002, + max_seq_len=514, + type_size=1, + pad_id=1, + dim=1024, + num_heads=16, + num_layers=24, + post_norm=True, + dropout=0.1, + eps=1e-5) + cfg.update(**kwargs) + + # init a model on device + with torch.device(device): + model = XLMRoberta(**cfg) + return model diff --git a/humo/utils/audio_processor_whisper.py b/humo/utils/audio_processor_whisper.py new file mode 100644 index 0000000000000000000000000000000000000000..557bf49c3abaa477405dbd6313c8b884a64ec4d0 --- /dev/null +++ b/humo/utils/audio_processor_whisper.py @@ -0,0 +1,173 @@ +# pylint: disable=C0301 +''' +This module contains the AudioProcessor class and related functions for processing audio data. +It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction, +and audio separation. The class is initialized with configuration parameters and can process +audio files using the provided models. +''' +import os +import subprocess + +import librosa +import numpy as np +import torch +from audio_separator.separator import Separator +from transformers import WhisperModel, AutoFeatureExtractor +import torch.nn.functional as F + + +def linear_interpolation_fps(features, input_fps, output_fps, output_len=None): + features = features.transpose(1, 2) # [1, C, T] + seq_len = features.shape[2] / float(input_fps) + if output_len is None: + output_len = int(seq_len * output_fps) + output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear') + return output_features.transpose(1, 2) + + +def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int): + p = subprocess.Popen([ + "ffmpeg", "-y", "-v", "error", "-i", input_audio_file, "-ar", str(sample_rate), output_audio_file + ]) + ret = p.wait() + assert ret == 0, "Resample audio failed!" + return output_audio_file + +class AudioProcessor: + """ + AudioProcessor is a class that handles the processing of audio files. + It takes care of preprocessing the audio files, extracting features + using wav2vec models, and separating audio signals if needed. 
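+    In this Whisper-based variant the features come from a ``WhisperModel``
+    encoder and are resampled from 50 Hz to 25 Hz in ``audio_emb_enc``, whose
+    layer grouping expects 33 encoder hidden states (i.e. a Whisper large
+    checkpoint). Usage sketch (paths are placeholders):
+
+        processor = AudioProcessor(16000, 25, '/path/to/whisper-large', None)
+        audio_emb, emb_len = processor.preprocess('/path/to/audio.wav')  # [T, 5, 1280]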
+
+    :param sample_rate: Sampling rate of the audio file
+    :param fps: Frames per second for the extracted features
+    :param wav2vec_model_path: Path to the Whisper model used for feature extraction
+    :param wav2vec_feature_type: Type of audio feature to extract
+    :param audio_separator_model_path: Path to the audio separator model
+    :param audio_separator_model_name: Name of the audio separator model
+    :param cache_dir: Directory to cache the intermediate results
+    :param device: Device to run the processing on
+    """
+    def __init__(
+        self,
+        sample_rate,
+        fps,
+        wav2vec_model_path,
+        wav2vec_feature_type,
+        audio_separator_model_path:str=None,
+        audio_separator_model_name:str=None,
+        cache_dir:str='',
+        device="cuda:0",
+    ) -> None:
+        self.sample_rate = sample_rate
+        self.fps = fps
+        self.device = device
+
+        self.whisper = WhisperModel.from_pretrained(wav2vec_model_path).to(device).eval()
+        self.whisper.requires_grad_(False)
+        self.feature_extractor = AutoFeatureExtractor.from_pretrained(wav2vec_model_path)
+
+        if audio_separator_model_name is not None:
+            try:
+                os.makedirs(cache_dir, exist_ok=True)
+            except OSError as _:
+                print("Failed to create the output cache dir.")
+            self.audio_separator = Separator(
+                output_dir=cache_dir,
+                output_single_stem="vocals",
+                model_file_dir=audio_separator_model_path,
+            )
+            self.audio_separator.load_model(audio_separator_model_name)
+            assert self.audio_separator.model_instance is not None, "Failed to load audio separator model."
+        else:
+            self.audio_separator = None
+            print("Use audio directly without vocals separator.")
+
+
+    def get_audio_feature(self, audio_path):
+        audio_input, sampling_rate = librosa.load(audio_path, sr=16000)
+        assert sampling_rate == 16000
+
+        audio_features = []
+        window = 750*640
+        for i in range(0, len(audio_input), window):
+            audio_feature = self.feature_extractor(audio_input[i:i+window],
+                                                   sampling_rate=sampling_rate,
+                                                   return_tensors="pt",
+                                                   ).input_features
+            audio_features.append(audio_feature)
+        audio_features = torch.cat(audio_features, dim=-1)
+        return audio_features, len(audio_input) // 640
+
+
+    def preprocess(self, audio_path: str):
+        audio_input, audio_len = self.get_audio_feature(audio_path)
+        audio_feature = audio_input.to(self.whisper.device).float()
+        window = 3000
+        audio_prompts = []
+        for i in range(0, audio_feature.shape[-1], window):
+            audio_prompt = self.whisper.encoder(audio_feature[:,:,i:i+window], output_hidden_states=True).hidden_states
+            audio_prompt = torch.stack(audio_prompt, dim=2)
+            audio_prompts.append(audio_prompt)
+
+        audio_prompts = torch.cat(audio_prompts, dim=1)
+        audio_prompts = audio_prompts[:,:audio_len*2]
+
+        audio_emb = self.audio_emb_enc(audio_prompts, wav_enc_type="whisper")
+
+        return audio_emb, audio_emb.shape[0]
+
+    def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"):
+        if wav_enc_type == "wav2vec":
+            feat_merge = audio_emb
+        elif wav_enc_type == "whisper":
+            # [1, T, 33, 1280]
+            feat0 = linear_interpolation_fps(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25)
+            feat1 = linear_interpolation_fps(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25)
+            feat2 = linear_interpolation_fps(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25)
+            feat3 = linear_interpolation_fps(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25)
+            feat4 = linear_interpolation_fps(audio_emb[:, :, 32], 50, 25)
+            feat_merge = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0]  # [T, 5, 1280]
+        else:
+            raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}")
+
+        return feat_merge
+
+    def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx,
audio_shift=2):
+        zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
+        zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
+        iter_ = 1 + (frame_num - 1) // 4
+        audio_emb_wind = []
+        for lt_i in range(iter_):
+            if lt_i == 0:  # latent_i
+                # First VAE latent frame: left-pad the audio window with zeros to mark it
+                st = frame0_idx + lt_i - 2
+                ed = frame0_idx + lt_i + 3
+                wind_feat = torch.stack([
+                    audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
+                    for i in range(st, ed)
+                ], dim=0)  # [5, 13, 768]
+                wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0)  # [8, 13, 768]
+            else:
+                st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
+                ed = frame0_idx + 1 + 4 * lt_i + audio_shift
+                wind_feat = torch.stack([
+                    audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
+                    for i in range(st, ed)
+                ], dim=0)  # [8, 13, 768]
+            audio_emb_wind.append(wind_feat)
+        audio_emb_wind = torch.stack(audio_emb_wind, dim=0)  # [iter_, 8, 13, 768]
+
+        return audio_emb_wind, ed - audio_shift
+
+    def close(self):
+        """
+        TODO: to be implemented
+        """
+        return self
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, _exc_type, _exc_val, _exc_tb):
+        self.close()
diff --git a/humo/utils/wav2vec.py b/humo/utils/wav2vec.py
new file mode 100644
index 0000000000000000000000000000000000000000..7088391abfa1e5684bd36330b79ff6a4570c733b
--- /dev/null
+++ b/humo/utils/wav2vec.py
@@ -0,0 +1,218 @@
+# pylint: disable=R0901
+# src/models/wav2vec.py
+
+"""
+This module defines the Wav2Vec model, which is a pre-trained model for speech recognition and understanding.
+It inherits from the Wav2Vec2Model class in the transformers library and provides additional functionalities
+such as feature extraction and encoding.
+
+Classes:
+    Wav2VecModel: Inherits from Wav2Vec2Model and adds additional methods for feature extraction and encoding.
+
+Functions:
+    linear_interpolation: Interpolates the features based on the sequence length.
+"""
+
+import torch.nn.functional as F
+from transformers import Wav2Vec2Model
+from transformers.modeling_outputs import BaseModelOutput
+
+
+class Wav2VecModel(Wav2Vec2Model):
+    """
+    Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library.
+    It inherits all the functionality of the Wav2Vec2Model and adds additional methods for feature extraction and encoding.
+    ...
+
+    Attributes:
+        base_model (Wav2Vec2Model): The base Wav2Vec2Model object.
+
+    Methods:
+        forward(input_values, seq_len, attention_mask=None, mask_time_indices=None,
+        output_attentions=None, output_hidden_states=None, return_dict=None):
+            Forward pass of the Wav2VecModel.
+            It takes input_values, seq_len, and other optional parameters as input and returns the output of the base model.
+
+        feature_extract(input_values, seq_len):
+            Extracts features from the input_values using the base model.
+
+        encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None):
+            Encodes the extracted features using the base model and returns the encoded features.
+    """
+    def forward(
+        self,
+        input_values,
+        seq_len,
+        attention_mask=None,
+        mask_time_indices=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        """
+        Forward pass of the Wav2Vec model.
+
+        Args:
+            self: The instance of the model.
+ input_values: The input values (waveform) to the model. + seq_len: The sequence length of the input values. + attention_mask: Attention mask to be used for the model. + mask_time_indices: Mask indices to be used for the model. + output_attentions: If set to True, returns attentions. + output_hidden_states: If set to True, returns hidden states. + return_dict: If set to True, returns a BaseModelOutput instead of a tuple. + + Returns: + The output of the Wav2Vec model. + """ + self.config.output_attentions = True + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + extract_features = linear_interpolation(extract_features, seq_len=seq_len) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states) + + if not return_dict: + return (hidden_states, ) + encoder_outputs[1:] + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + + def feature_extract( + self, + input_values, + seq_len, + ): + """ + Extracts features from the input values and returns the extracted features. + + Parameters: + input_values (torch.Tensor): The input values to be processed. + seq_len (torch.Tensor): The sequence lengths of the input values. + + Returns: + extracted_features (torch.Tensor): The extracted features from the input values. + """ + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + extract_features = linear_interpolation(extract_features, seq_len=seq_len) + + return extract_features + + def encode( + self, + extract_features, + attention_mask=None, + mask_time_indices=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + """ + Encodes the input features into the output space. + + Args: + extract_features (torch.Tensor): The extracted features from the audio signal. + attention_mask (torch.Tensor, optional): Attention mask to be used for padding. + mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension. + output_attentions (bool, optional): If set to True, returns the attention weights. + output_hidden_states (bool, optional): If set to True, returns all hidden states. + return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of the tuple. + + Returns: + The encoded output features. 
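+
+        Example sketch (assumes ``model = Wav2VecModel.from_pretrained(...)`` on a
+        wav2vec2 checkpoint and that ``return_dict`` is enabled in its config;
+        ``wav`` and ``num_frames`` are placeholder names):
+
+            feats = model.feature_extract(wav, seq_len=num_frames)
+            out = model.encode(feats)
+            hidden = out.last_hidden_state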
+ """ + self.config.output_attentions = True + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states) + + if not return_dict: + return (hidden_states, ) + encoder_outputs[1:] + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +def linear_interpolation(features, seq_len): + """ + Transpose the features to interpolate linearly. + + Args: + features (torch.Tensor): The extracted features to be interpolated. + seq_len (torch.Tensor): The sequence lengths of the features. + + Returns: + torch.Tensor: The interpolated features. + """ + features = features.transpose(1, 2) + output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear') + return output_features.transpose(1, 2) + + +def linear_interpolation_fps(features, input_fps, output_fps, output_len=None): + features = features.transpose(1, 2) # [1, C, T] + seq_len = features.shape[2] / float(input_fps) + if output_len is None: + output_len = int(seq_len * output_fps) + output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear') + return output_features.transpose(1, 2) \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e06a4e85c6290fb63503decc5f3e8a724e836a81 --- /dev/null +++ b/main.py @@ -0,0 +1,28 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Inference codes adapted from [SeedVR] +# https://github.com/ByteDance-Seed/SeedVR/blob/main/projects/inference_seedvr2_7b.py + +from sys import argv +import sys + +path_to_insert = "humo" +if path_to_insert not in sys.path: + sys.path.insert(0, path_to_insert) + +from common.config import load_config, create_object + +# Load config. 
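+# argv[1] is the path to a YAML config; argv[2:] are extra `key=value` arguments
+# forwarded to load_config as overrides (see scripts/infer_*.sh).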
+config = load_config(argv[1], argv[2:]) + +runner = create_object(config) +runner.entrypoint() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e4705239063aed5fc9e42fba911182af3b81e09c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +mediapy +diffusers +transformers +torchao +imageio +numpy==1.26.4 +ftfy +audio-separator==0.24.1 +onnxruntime +omegaconf +moviepy==1.0.3 +kernels diff --git a/scripts/infer_ta.sh b/scripts/infer_ta.sh new file mode 100644 index 0000000000000000000000000000000000000000..89be8106aaa2e23c7fc8ac1a57b2ff3da50c1541 --- /dev/null +++ b/scripts/infer_ta.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +torchrun --node_rank=0 --nproc_per_node=8 --nnodes=1 \ + --rdzv_endpoint=127.0.0.1:12345 \ + --rdzv_conf=timeout=900,join_timeout=900,read_timeout=900 \ + main.py humo/configs/inference/generate.yaml \ + generation.frames=97 \ + generation.scale_a=5.5 \ + generation.scale_t=5.0 \ + generation.mode=TA \ + generation.height=720 \ + generation.width=1280 \ + diffusion.timesteps.sampling.steps=50 \ + generation.positive_prompt=./examples/test_case.json \ + generation.output.dir=./output diff --git a/scripts/infer_ta_1_7B.sh b/scripts/infer_ta_1_7B.sh new file mode 100644 index 0000000000000000000000000000000000000000..d87331640b06800aa4fbb2a22a87620656010629 --- /dev/null +++ b/scripts/infer_ta_1_7B.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +# Run on a single GPU +CUDA_VISIBLE_DEVICES=0 torchrun --node_rank=0 --nproc_per_node=1 --nnodes=1 \ + --rdzv_endpoint=127.0.0.1:12345 \ + --rdzv_conf=timeout=900,join_timeout=900,read_timeout=900 \ + main.py humo/configs/inference/generate_1_7B.yaml \ + dit.sp_size=1 \ + generation.frames=97 \ + generation.scale_t=7.0 \ + generation.scale_a=7.5 \ + generation.mode=TA \ + generation.height=480 \ + generation.width=832 \ + diffusion.timesteps.sampling.steps=50 \ + generation.positive_prompt=./examples/test_case.json \ + generation.output.dir=./output + +# # Run on 2 GPUs +# CUDA_VISIBLE_DEVICES=0,1 torchrun --node_rank=0 --nproc_per_node=2 --nnodes=1 \ +# --rdzv_endpoint=127.0.0.1:12345 \ +# --rdzv_conf=timeout=900,join_timeout=900,read_timeout=900 \ +# main.py humo/configs/inference/generate_1_7B.yaml \ +# dit.sp_size=2 \ +# generation.frames=97 \ +# generation.scale_t=7.0 \ +# generation.scale_a=7.5 \ +# generation.mode=TA \ +# generation.height=480 \ +# generation.width=832 \ +# diffusion.timesteps.sampling.steps=50 \ +# generation.positive_prompt=./examples/test_case.json \ +# generation.output.dir=./output + +# # Run on 4 GPUs +# CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --node_rank=0 --nproc_per_node=4 --nnodes=1 \ +# --rdzv_endpoint=127.0.0.1:12345 \ +# --rdzv_conf=timeout=900,join_timeout=900,read_timeout=900 \ +# main.py humo/configs/inference/generate_1_7B.yaml \ +# dit.sp_size=4 \ +# generation.frames=97 \ +# generation.scale_t=7.0 \ +# generation.scale_a=7.5 \ +# generation.mode=TA \ +# generation.height=480 \ +# generation.width=832 \ +# diffusion.timesteps.sampling.steps=50 \ +# generation.positive_prompt=./examples/test_case.json \ +# generation.output.dir=./output \ No newline at end of file diff --git a/scripts/infer_tia.sh b/scripts/infer_tia.sh new file mode 100644 index 0000000000000000000000000000000000000000..e628342db1058ebcb628ab61742bad878e747314 --- /dev/null +++ b/scripts/infer_tia.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +torchrun --node_rank=0 --nproc_per_node=8 --nnodes=1 \ + --rdzv_endpoint=127.0.0.1:12345 \ + 
--rdzv_conf=timeout=900,join_timeout=900,read_timeout=900 \ + main.py humo/configs/inference/generate.yaml \ + generation.frames=97 \ + generation.scale_a=5.5 \ + generation.scale_t=5.0 \ + generation.mode=TIA \ + generation.height=720 \ + generation.width=1280 \ + diffusion.timesteps.sampling.steps=50 \ + generation.positive_prompt=./examples/test_case.json \ + generation.output.dir=./output diff --git a/scripts/infer_tia_1_7B.sh b/scripts/infer_tia_1_7B.sh new file mode 100644 index 0000000000000000000000000000000000000000..45e5e5ae9a12fc062fbb0d585c741eaa87f5b856 --- /dev/null +++ b/scripts/infer_tia_1_7B.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +# Run on a single GPU +CUDA_VISIBLE_DEVICES=0 torchrun --node_rank=0 --nproc_per_node=1 --nnodes=1 \ + --rdzv_endpoint=127.0.0.1:12345 \ + --rdzv_conf=timeout=900,join_timeout=900,read_timeout=900 \ + main.py humo/configs/inference/generate_1_7B.yaml \ + dit.sp_size=1 \ + generation.frames=97 \ + generation.scale_t=7.0 \ + generation.scale_i=4.0 \ + generation.scale_a=7.5 \ + generation.mode=TIA \ + generation.height=480 \ + generation.width=832 \ + diffusion.timesteps.sampling.steps=50 \ + generation.positive_prompt=./examples/test_case.json \ + generation.output.dir=./output + +# # Run on 2 GPUs +# CUDA_VISIBLE_DEVICES=0,1 torchrun --node_rank=0 --nproc_per_node=2 --nnodes=1 \ +# --rdzv_endpoint=127.0.0.1:12345 \ +# --rdzv_conf=timeout=900,join_timeout=900,read_timeout=900 \ +# main.py humo/configs/inference/generate_1_7B.yaml \ +# dit.sp_size=2 \ +# generation.frames=97 \ +# generation.scale_t=7.0 \ +# generation.scale_i=4.0 \ +# generation.scale_a=7.5 \ +# generation.mode=TIA \ +# generation.height=480 \ +# generation.width=832 \ +# diffusion.timesteps.sampling.steps=50 \ +# generation.positive_prompt=./examples/test_case.json \ +# generation.output.dir=./output + +# # Run on 4 GPUs +# CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --node_rank=0 --nproc_per_node=4 --nnodes=1 \ +# --rdzv_endpoint=127.0.0.1:12345 \ +# --rdzv_conf=timeout=900,join_timeout=900,read_timeout=900 \ +# main.py humo/configs/inference/generate_1_7B.yaml \ +# dit.sp_size=4 \ +# generation.frames=97 \ +# generation.scale_t=7.0 \ +# generation.scale_i=4.0 \ +# generation.scale_a=7.5 \ +# generation.mode=TIA \ +# generation.height=480 \ +# generation.width=832 \ +# diffusion.timesteps.sampling.steps=50 \ +# generation.positive_prompt=./examples/test_case.json \ +# generation.output.dir=./output