RaphaelLiu commited on
Commit
759dfe0
·
verified ·
1 Parent(s): ce6b04d

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. LICENSE +201 -0
  3. README.md +78 -3
  4. assets/grid.mp4 +3 -0
  5. assets/mochi-factory.webp +3 -0
  6. contrib/README.md +6 -0
  7. contrib/modal/lora.yaml +58 -0
  8. contrib/modal/main.py +285 -0
  9. contrib/modal/readme.md +55 -0
  10. decoder.safetensors +3 -0
  11. demos/api_example.py +53 -0
  12. demos/cli.py +163 -0
  13. demos/comfyui_nodes.py +0 -0
  14. demos/fine_tuner/README.md +99 -0
  15. demos/fine_tuner/configs/lora.yaml +58 -0
  16. demos/fine_tuner/dataset.py +45 -0
  17. demos/fine_tuner/embed_captions.py +66 -0
  18. demos/fine_tuner/encode_videos.py +142 -0
  19. demos/fine_tuner/preprocess.bash +64 -0
  20. demos/fine_tuner/run.bash +92 -0
  21. demos/fine_tuner/train.py +396 -0
  22. demos/fine_tuner/trim_and_crop_videos.py +110 -0
  23. demos/gradio_ui.py +57 -0
  24. demos/test_encoder_decoder.py +79 -0
  25. encoder.safetensors +3 -0
  26. model_index.json +24 -0
  27. pusa_v0_dit.safetensors +3 -0
  28. pyproject.toml +37 -0
  29. pyrightconfig.json +4 -0
  30. requirements.txt +14 -0
  31. scheduler/scheduler_config.json +12 -0
  32. scripts/download_weights.py +41 -0
  33. scripts/format.bash +5 -0
  34. scripts/pytorch_to_safe_tensors.py +24 -0
  35. scripts/typecheck.bash +2 -0
  36. scripts/weights_to_fp8.py +0 -0
  37. src/genmo/lib/attn_imports.py +29 -0
  38. src/genmo/lib/progress.py +87 -0
  39. src/genmo/lib/utils.py +67 -0
  40. src/genmo/mochi_preview/__init__.py +0 -0
  41. src/genmo/mochi_preview/dit/joint_model/__init__.py +0 -0
  42. src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py +737 -0
  43. src/genmo/mochi_preview/dit/joint_model/context_parallel.py +158 -0
  44. src/genmo/mochi_preview/dit/joint_model/layers.py +179 -0
  45. src/genmo/mochi_preview/dit/joint_model/lora.py +112 -0
  46. src/genmo/mochi_preview/dit/joint_model/mod_rmsnorm.py +15 -0
  47. src/genmo/mochi_preview/dit/joint_model/residual_tanh_gated_rmsnorm.py +20 -0
  48. src/genmo/mochi_preview/dit/joint_model/rope_mixed.py +88 -0
  49. src/genmo/mochi_preview/dit/joint_model/temporal_rope.py +34 -0
  50. src/genmo/mochi_preview/dit/joint_model/utils.py +109 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/grid.gif filter=lfs diff=lfs merge=lfs -text
37
+ assets/grid.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ assets/mochi-factory.webp filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2024 Genmo
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,78 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Pusa VidGen
3
+
4
+ [Codes](https://github.com/Yaofang-Liu/Pusa-VidGen) | [Hugging Face](https://huggingface.co/RaphaelLiu/Pusa-V0.5)
5
+
6
+ ## Overview
7
+
8
+ Pusa is an advanced open-source video generation model that builds upon Mochi 1 with significant enhancements. It supports multiple video generation tasks while maintaining high-fidelity motion and strong prompt adherence. The model is released under a permissive Apache 2.0 license.
9
+
10
+ ✨ **Key Features**
11
+ - **Multi-task support**: Text-to-Video, Image-to-Video, Interpolation, Transition, Loop, Long Video, and more
12
+ - **Cost-efficient**: Trained with just 100 H100 GPU hours
13
+ - **Full Open-Source**: Code, architecture, and training details included
14
+
15
+ 🔍 **Unique Architecture**
16
+ - A novel diffusion model supporting frame-level noise with vectorized timesteps originally introduced in the [FVDM paper](https://arxiv.org/abs/2410.03160) for flexibility and scalability
17
+
18
+ ## Installation
19
+
20
+ Install using [uv](https://github.com/astral-sh/uv):
21
+
22
+ ```bash
23
+ git clone https://github.com/Yaofang-Liu/Pusa-VidGen
24
+ cd models
25
+ pip install uv
26
+ uv venv .venv
27
+ source .venv/bin/activate
28
+ uv pip install setuptools
29
+ uv pip install -e . --no-build-isolation
30
+ ```
31
+
32
+
33
+ If you want to install flash attention, you can use:
34
+ ```
35
+ uv pip install -e .[flash] --no-build-isolation
36
+ ```
37
+
38
+ You will also need to install [FFMPEG](https://www.ffmpeg.org/) to turn your outputs into videos.
39
+
40
+ ## Download Weights
41
+
42
+ You can use the Hugging Face CLI to download the model:
43
+ ```
44
+ pip install huggingface_hub
45
+ huggingface-cli download RaphaelLiu/Pusa-V0.5 --local-dir <path_to_downloaded_directory>
46
+
47
+ ```
48
+ Or, directly download the weights from [Hugging Face](https://huggingface.co/RaphaelLiu/Pusa-V0.5) to a folder on your computer.
49
+
50
+
51
+ ## Limitations
52
+ Pusa has a few known limitations. The base model Mochi generates videos at 480p. We expect to get better results when use our proposed method to more powerful models like Wan2.1. We also welcom collobartion from the community to improve the model and extend its capabilities.
53
+
54
+ ## Related Work
55
+ - [mochi](https://huggingface.co/genmo/mochi-1-preview) is our base model, top 3 open-source video generation models in Artifical Analysis Leaderboard for video generation.
56
+ - [FVDM](https://arxiv.org/abs/2410.03160) introduces the vectorized timestep approach that inspired Pusa's frame-level noise control.
57
+
58
+ ## BibTeX
59
+ ```
60
+ @misc{Liu2025pusa,
61
+ title={Pusa: A Next-Level All-in-One Video Diffusion Model},
62
+ author={Yaofang Liu and Rui Liu},
63
+ year={2025},
64
+ publisher = {GitHub},
65
+ journal = {GitHub repository},
66
+ howpublished={\url{https://github.com/Yaofang-Liu/Pusa-VidGen}}
67
+ }
68
+ ```
69
+
70
+ ```
71
+ @article{liu2024redefining,
72
+ title={Redefining Temporal Modeling in Video Diffusion: The Vectorized Timestep Approach},
73
+ author={Liu, Yaofang and Ren, Yumeng and Cun, Xiaodong and Artola, Aitor and Liu, Yang and Zeng, Tieyong and Chan, Raymond H and Morel, Jean-michel},
74
+ journal={arXiv preprint arXiv:2410.03160},
75
+ year={2024}
76
+ }
77
+ ```
78
+
assets/grid.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e10304cf30b68f92b82438a33d4dd1dade8169e142fee538ceebb9d91565604d
3
+ size 6905424
assets/mochi-factory.webp ADDED

Git LFS Details

  • SHA256: dd70d39a9a26d7e69c9264caa947da0ae5c3695c384529ce469ecd1703abd165
  • Pointer size: 131 Bytes
  • Size of remote file: 560 kB
contrib/README.md ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Mochi Community Contributions
2
+
3
+ `mochi/contrib` contains community contributed pipelines for running and customizing Mochi.
4
+
5
+ ## Index:
6
+ - `mochi/contrib/modal` - [Script](contrib/modal/readme.md) for fine-tuning Mochi on Modal GPUs.
contrib/modal/lora.yaml ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ init_checkpoint_path: /weights/dit.safetensors
2
+ checkpoint_dir: /finetunes/my_mochi_lora
3
+ train_data_dir: /videos_prepared
4
+ attention_mode: sdpa
5
+ single_video_mode: false # Useful for debugging whether your model can learn a single video
6
+
7
+ # You only need this if you're using wandb
8
+ wandb:
9
+ # project: mochi_1_lora
10
+ # name: ${checkpoint_dir}
11
+ # group: null
12
+
13
+ optimizer:
14
+ lr: 2e-4
15
+ weight_decay: 0.01
16
+
17
+ model:
18
+ type: lora
19
+ kwargs:
20
+ # Apply LoRA to the QKV projection and the output projection of the attention block.
21
+ qkv_proj_lora_rank: 16
22
+ qkv_proj_lora_alpha: 16
23
+ qkv_proj_lora_dropout: 0.
24
+ out_proj_lora_rank: 16
25
+ out_proj_lora_alpha: 16
26
+ out_proj_lora_dropout: 0.
27
+
28
+ training:
29
+ model_dtype: bf16
30
+ warmup_steps: 200
31
+ num_qkv_checkpoint: 48
32
+ num_ff_checkpoint: 48
33
+ num_post_attn_checkpoint: 48
34
+ num_steps: 2000
35
+ save_interval: 200
36
+ caption_dropout: 0.1
37
+ grad_clip: 0.0
38
+ save_safetensors: true
39
+
40
+ # Used for generating samples during training to monitor progress ...
41
+ sample:
42
+ interval: 200
43
+ output_dir: ${checkpoint_dir}/samples
44
+ decoder_path: /weights/decoder.safetensors
45
+ prompts:
46
+ - A pristine snowglobe featuring a winter scene sits peacefully. The glass begins to crumble into fine powder, as the entire sphere deteriorates into sparkling dust that drifts outward. The fake snow mingles with the crystalline particles, creating a glittering cloud captured in high-speed photography.
47
+ - A vintage pocket watch ticks quietly on an antique desk. Its brass casing starts to deteriorate, turning to fine metallic powder that lifts into the air. The gears and springs fragment into microscopic particles, each piece breaking down into a shimmering bronze dust that hangs suspended. The scene is richly detailed with warm, brass tones.
48
+ - A cello is propped up against a wall, a single spotlight illuminating it. The wooden surface begins to decay into fine sawdust, the instrument gradually breaking apart as its form disintegrates into a cloud of earthen particles. The strings unravel into delicate fibers that float amidst the swirling wooden dust. The scene is vibrant and colorful.
49
+ - A graphics card sits inside an oven, heatwaves around it. The silicon and metal components begin to break down at a molecular level, deteriorating into a dark cloud of fine metallic and mineral dust that hangs suspended in the heated air. The scene is darkly lit, high contrast, with a focus on the suspended particles.
50
+ - A delicate porcelain teacup sits on a marble countertop. The ceramic structure begins to crumble into a fine, chalk-like powder, breaking down into countless microscopic white particles that drift upward in graceful patterns. The scene is bright and crisp with dramatic lighting illuminating the cloud of porcelain dust.
51
+ seed: 12345
52
+ kwargs:
53
+ height: 480
54
+ width: 848
55
+ num_frames: 37
56
+ num_inference_steps: 64
57
+ sigma_schedule_python_code: "linear_quadratic_schedule(64, 0.025)"
58
+ cfg_schedule_python_code: "[6.0] * 64"
contrib/modal/main.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import modal
2
+ from pathlib import Path
3
+
4
+ # Creating our Modal App
5
+ app = modal.App("mochi-finetune")
6
+
7
+ # Creating volumes for data, intermediate data, and produced weights
8
+ videos_volume = modal.Volume.from_name("mochi-tune-videos", create_if_missing=True)
9
+ videos_prepared_volume = modal.Volume.from_name("mochi-tune-videos-prepared", create_if_missing=True)
10
+ weights_volume = modal.Volume.from_name("mochi-tune-weights", create_if_missing=True)
11
+ finetunes_volume = modal.Volume.from_name("mochi-tune-finetunes", create_if_missing=True)
12
+ outputs_volume = modal.Volume.from_name("mochi-tune-outputs", create_if_missing=True)
13
+
14
+ USERNAME = "genmoai"
15
+ REPOSITORY = "mochi"
16
+ CLONE_CMD = f"git clone https://github.com/{USERNAME}/{REPOSITORY}.git"
17
+
18
+ # Building our container image
19
+ base_img = (
20
+ modal.Image.debian_slim()
21
+ .apt_install("git", "ffmpeg", "bc", "zlib1g-dev", "libjpeg-dev", "wget")
22
+ .run_commands(CLONE_CMD)
23
+ .workdir(REPOSITORY)
24
+ .pip_install("gdown", "setuptools", "wheel")
25
+ .run_commands('pip install -e . --no-build-isolation')
26
+ )
27
+
28
+ MINUTES = 60
29
+ HOURS = 60 * MINUTES
30
+
31
+ # Remote function for downloading a labeled video dataset from Google Drive
32
+ # Run it with:
33
+ # modal run main::download_videos
34
+ @app.function(image=base_img,
35
+ volumes={
36
+ "/videos": videos_volume,
37
+ }
38
+ )
39
+ def download_videos():
40
+ '''Downloads videos from google drive into our volume'''
41
+ import gdown
42
+ import zipfile
43
+
44
+ name = "dissolve"
45
+ url = "https://drive.google.com/uc?id=1ldoBppcsv5Ueoikh0zCmNviojRCrGXQN"
46
+ output = f"{name}.zip"
47
+ gdown.download(url, output, quiet=False)
48
+ with zipfile.ZipFile(output, "r") as zip_ref:
49
+ zip_ref.extractall("/videos")
50
+
51
+ # Remote function for downloading the model weights from Hugging Face
52
+ # Run it with:
53
+ # modal run main::download_weights
54
+ @app.function(image=base_img,
55
+ volumes={
56
+ "/weights": weights_volume,
57
+ },
58
+ timeout=1*HOURS,
59
+ )
60
+ def download_weights():
61
+ # HF-transfer and snapshot download tend to hang on the large model, so we download it manually with wget
62
+ import subprocess
63
+ print("🍡 Downloading weights from Hugging Face. This may take 30 minutes.")
64
+ # ~30 min
65
+ subprocess.run(["wget", "https://huggingface.co/genmo/mochi-1-preview/resolve/main/dit.safetensors", "-O", "/weights/dit.safetensors"])
66
+ # ~1 min
67
+ subprocess.run(["wget", "https://huggingface.co/genmo/mochi-1-preview/resolve/main/decoder.safetensors", "-O", "/weights/decoder.safetensors"])
68
+ # ~20 sec
69
+ subprocess.run(["wget", "https://huggingface.co/genmo/mochi-1-preview/resolve/main/encoder.safetensors", "-O", "/weights/encoder.safetensors"])
70
+
71
+ # Remote function for preprocessing the video dataset
72
+ # Run it with:
73
+ # modal run main::preprocess
74
+ @app.function(
75
+ image=base_img,
76
+ volumes={
77
+ "/videos": videos_volume,
78
+ "/videos_prepared": videos_prepared_volume,
79
+ "/weights": weights_volume,
80
+ },
81
+ timeout=30*MINUTES,
82
+ gpu="H100"
83
+ )
84
+ def preprocess():
85
+ import subprocess
86
+ print("🍡 Preprocessing videos. This may take 2-3 minutes.")
87
+ video_dir = "videos_dissolve"
88
+ subprocess.run([
89
+ "bash", "demos/fine_tuner/preprocess.bash",
90
+ "-v", f"/videos/{video_dir}/",
91
+ "-o", "/videos_prepared/",
92
+ "-w", "/weights/",
93
+ "-n", "37"
94
+ ])
95
+
96
+ # Remote function for finetuning the model using the prepared dataset
97
+ # Configure the run in lora.yaml
98
+ # Run it with:
99
+ # modal run main::finetune
100
+ @app.function(
101
+ image=base_img,
102
+ volumes={
103
+ "/videos": videos_volume,
104
+ "/videos_prepared": videos_prepared_volume,
105
+ "/weights": weights_volume,
106
+ "/finetunes": finetunes_volume,
107
+ },
108
+ mounts=[modal.Mount.from_local_file("lora.yaml", remote_path=f"{REPOSITORY}/lora.yaml")],
109
+ timeout=4*HOURS,
110
+ gpu="H100"
111
+ )
112
+ def finetune():
113
+ import subprocess
114
+ print("🍡 Finetuning Mochi. This may take 3 hours.")
115
+ print("🍡 See your mochi-tune-finetunes volume for intermediate checkpoints and samples.")
116
+ subprocess.run([
117
+ "bash", "demos/fine_tuner/run.bash",
118
+ "-c", "lora.yaml", # from our locally mounted yaml file
119
+ "-n", "1",
120
+ ])
121
+
122
+ # Remote function (Modal @cls) for running inference on one or multiple videos
123
+ # Run it with the @local_entrypoint below
124
+ @app.cls(
125
+ image = base_img,
126
+ volumes={
127
+ "/weights": weights_volume,
128
+ "/finetunes": finetunes_volume,
129
+ "/outputs": outputs_volume,
130
+ },
131
+ timeout=30*MINUTES,
132
+ gpu="H100"
133
+ )
134
+ class MochiLora():
135
+ def __init__(self, model_dir: str = "/weights", lora_path: str = None, cpu_offload: bool = False):
136
+ self.model_dir = model_dir
137
+ self.lora_path = lora_path
138
+ self.cpu_offload = cpu_offload
139
+
140
+ @modal.enter()
141
+ def start(self):
142
+ from genmo.mochi_preview.pipelines import (
143
+ DecoderModelFactory,
144
+ DitModelFactory,
145
+ MochiMultiGPUPipeline,
146
+ MochiSingleGPUPipeline,
147
+ T5ModelFactory,
148
+ )
149
+ import torch
150
+
151
+ """Initialize the model - this runs once when the container starts"""
152
+ print("🍡 Loading Mochi model.")
153
+
154
+ self.num_gpus = torch.cuda.device_count()
155
+
156
+ # Configure pipeline based on GPU count
157
+ klass = MochiSingleGPUPipeline if self.num_gpus == 1 else MochiMultiGPUPipeline
158
+
159
+ kwargs = dict(
160
+ text_encoder_factory=T5ModelFactory(),
161
+ dit_factory=DitModelFactory(
162
+ model_path=f"{self.model_dir}/dit.safetensors",
163
+ lora_path=self.lora_path,
164
+ model_dtype="bf16",
165
+ ),
166
+ decoder_factory=DecoderModelFactory(
167
+ model_path=f"{self.model_dir}/decoder.safetensors",
168
+ ),
169
+ )
170
+
171
+ if self.num_gpus > 1:
172
+ assert not self.lora_path, f"Lora not supported in multi-GPU mode"
173
+ assert not self.cpu_offload, "CPU offload not supported in multi-GPU mode"
174
+ kwargs["world_size"] = self.num_gpus
175
+ else:
176
+ kwargs["cpu_offload"] = self.cpu_offload
177
+ kwargs["decode_type"] = "tiled_spatial"
178
+ kwargs["fast_init"] = not self.lora_path
179
+ kwargs["strict_load"] = not self.lora_path
180
+ kwargs["decode_args"] = dict(overlap=8)
181
+
182
+ self.pipeline = klass(**kwargs)
183
+ print(f"🍡 Model loaded successfully with {self.num_gpus} GPUs")
184
+
185
+ @modal.method()
186
+ def generate(self,
187
+ prompt: str,
188
+ negative_prompt: str = "",
189
+ width: int = 848,
190
+ height: int = 480,
191
+ num_frames: int = 163,
192
+ seed: int = 1710977262,
193
+ cfg_scale: float = 6.0,
194
+ num_inference_steps: int = 64) -> str:
195
+ """Generate video based on the prompt and parameters"""
196
+
197
+ print("🍡 Generating video.")
198
+
199
+ import json
200
+ import os
201
+ import time
202
+
203
+ import numpy as np
204
+
205
+ from genmo.lib.progress import progress_bar
206
+ from genmo.lib.utils import save_video
207
+ from genmo.mochi_preview.pipelines import linear_quadratic_schedule
208
+
209
+
210
+ # Create sigma schedule
211
+ sigma_schedule = linear_quadratic_schedule(num_inference_steps, 0.025)
212
+ cfg_schedule = [cfg_scale] * num_inference_steps
213
+
214
+ args = {
215
+ "height": height,
216
+ "width": width,
217
+ "num_frames": num_frames,
218
+ "sigma_schedule": sigma_schedule,
219
+ "cfg_schedule": cfg_schedule,
220
+ "num_inference_steps": num_inference_steps,
221
+ "batch_cfg": False,
222
+ "prompt": prompt,
223
+ "negative_prompt": negative_prompt,
224
+ "seed": seed,
225
+ }
226
+
227
+ with progress_bar(type="tqdm"):
228
+ final_frames = self.pipeline(**args)
229
+ final_frames = final_frames[0]
230
+
231
+ assert isinstance(final_frames, np.ndarray)
232
+ assert final_frames.dtype == np.float32
233
+
234
+ # Save to mounted volume
235
+ output_dir = "/outputs" # Assuming this path exists in the mounted volume
236
+ os.makedirs(output_dir, exist_ok=True)
237
+ output_path = os.path.join(output_dir, f"output_{int(time.time())}.mp4")
238
+
239
+ save_video(final_frames, output_path)
240
+
241
+ # Save generation parameters
242
+ json_path = os.path.splitext(output_path)[0] + ".json"
243
+ json.dump(args, open(json_path, "w"), indent=4)
244
+
245
+ print(f"🍡 Video saved to {output_path}")
246
+ outputs_volume.commit()
247
+ return output_path.split("/")[-1]
248
+
249
+ # Local entrypoint for using the MochiLora class
250
+ # Select the lora_path you'd want to use from the finetunes volume
251
+ # Then it with:
252
+ # modal run main
253
+ @app.local_entrypoint()
254
+ def main(
255
+ prompt="A pristine snowglobe featuring a winter scene sits peacefully. The glass begins to crumble into fine powder, as the entire sphere deteriorates into sparkling dust that drifts outward. The fake snow mingles with the crystalline particles, creating a glittering cloud captured in high-speed photography.",
256
+ negative_prompt="blurry, low quality",
257
+ width=848,
258
+ height=480,
259
+ num_frames=49, # (num_frames - 1) must be divisible by 6
260
+ seed=1710977262,
261
+ cfg_scale=6.0,
262
+ num_inference_steps=64,
263
+ lora_path="/finetunes/my_mochi_lora/model_2000.lora.safetensors",
264
+ cpu_offload=True,
265
+ ):
266
+ lora = MochiLora(
267
+ lora_path=lora_path, # your lora path
268
+ cpu_offload=cpu_offload,
269
+ )
270
+ output_path = lora.generate.remote(
271
+ prompt=prompt,
272
+ negative_prompt=negative_prompt,
273
+ width=width,
274
+ height=height,
275
+ num_frames=num_frames,
276
+ seed=seed,
277
+ cfg_scale=cfg_scale,
278
+ num_inference_steps=num_inference_steps,
279
+ )
280
+
281
+ local_dir = Path("/tmp/mochi")
282
+ local_dir.mkdir(exist_ok=True, parents=True)
283
+ local_path = local_dir / output_path
284
+ local_path.write_bytes(b"".join(outputs_volume.read_file(output_path)))
285
+ print(f"🍡 video saved locally at {local_path}")
contrib/modal/readme.md ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Finetuning Mochi with LoRA on Modal
2
+
3
+ This example demonstrates how to run the Mochi finetuner on Modal GPUs.
4
+
5
+ ### Setup
6
+ Install [Modal](https://modal.com/docs/guide).
7
+ ```bash
8
+ pip install modal
9
+ modal setup
10
+ ```
11
+
12
+ ### Fetch the dataset
13
+ There is a labeled dataset for a dissolving visual effect available on Google Drive. Download it into the `mochi-tune-videos` modal volume with:
14
+ ```bash
15
+ modal run main::download_videos
16
+ ```
17
+
18
+ ### Download the model weights
19
+ Download the model weights from Hugging Face into the `mochi-tune-weights` modal volume with:
20
+ ```bash
21
+ modal run -d main::download_weights
22
+ ```
23
+ Note that this download can take more than 30 minutes. The `-d` flag allows you to exit the terminal session without losing progress.
24
+
25
+ ### Prepare the dataset
26
+ We now run the preprocessing script to prepare the dataset for finetuning:
27
+ ```bash
28
+ modal run main::preprocess
29
+ ```
30
+ This puts preprocessed training input into the `mochi-tune-videos-prepared` modal volume.
31
+
32
+ ### Finetuning
33
+ Finetune the model using the prepared dataset.
34
+
35
+ You may configure the finetune run using the `lora.yaml` file, such as number of steps, learning rate, etc.
36
+
37
+ Run the finetuning with:
38
+ ```bash
39
+ modal run -d main::finetune
40
+ ```
41
+
42
+ This will produce a series of checkpoints, as well as video samples generated along the training process. You can view these files in the Modal `moshi-tune-finetunes` volume using the Storage tab in the dashboard.
43
+
44
+ ### Inference
45
+ You can now use the MochiLora class to generate videos from a prompt. The `main` entrypoint will initialize the model to use the specified LoRA weights from your finetuning run.
46
+
47
+ ```bash
48
+ modal run main
49
+ ```
50
+ or with more parameters:
51
+ ```bash
52
+ modal run main lora-path="/finetunes/my_mochi_lora/model_1000.lora.safetensors" prompt="A pristine snowglobe featuring a winter scene sits peacefully. The glass begins to crumble into fine powder, as the entire sphere deteriorates into sparkling dust that drifts outward."
53
+ ```
54
+
55
+ See modal run main --help for all inference options.
decoder.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:641920faaf20e5404ddb5553ce3e295c21ed9b4bc5f6fe7c930811b84099cb14
3
+ size 1450122828
demos/api_example.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python
2
+ import sys
3
+ from pathlib import Path
4
+ from textwrap import dedent
5
+
6
+ from genmo.lib.progress import progress_bar
7
+ from genmo.lib.utils import save_video
8
+ from genmo.mochi_preview.pipelines import (
9
+ DecoderModelFactory,
10
+ DitModelFactory,
11
+ MochiSingleGPUPipeline,
12
+ T5ModelFactory,
13
+ linear_quadratic_schedule,
14
+ )
15
+
16
+ MOCHI_DIR = sys.argv[1]
17
+ assert Path(MOCHI_DIR).exists(), f"Model directory {MOCHI_DIR} does not exist."
18
+ pipeline = MochiSingleGPUPipeline(
19
+ text_encoder_factory=T5ModelFactory(),
20
+ dit_factory=DitModelFactory(model_path=f"{MOCHI_DIR}/dit.safetensors", model_dtype="bf16"),
21
+ decoder_factory=DecoderModelFactory(
22
+ model_path=f"{MOCHI_DIR}/vae.safetensors",
23
+ model_stats_path=f"{MOCHI_DIR}/vae_stats.json",
24
+ ),
25
+ cpu_offload=True,
26
+ decode_type="tiled_full",
27
+ )
28
+
29
+ PROMPT = dedent("""
30
+ A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl
31
+ filled with lemons and sprigs of mint against a peach-colored background.
32
+ The hand gently tosses the lemon up and catches it, showcasing its smooth texture.
33
+ A beige string bag sits beside the bowl, adding a rustic touch to the scene.
34
+ Additional lemons, one halved, are scattered around the base of the bowl.
35
+ The even lighting enhances the vibrant colors and creates a fresh,
36
+ inviting atmosphere.
37
+ """)
38
+
39
+ video = pipeline(
40
+ height=480,
41
+ width=848,
42
+ num_frames=31,
43
+ num_inference_steps=64,
44
+ sigma_schedule=linear_quadratic_schedule(64, 0.025),
45
+ cfg_schedule=[4.5] * 64,
46
+ batch_cfg=False,
47
+ prompt=PROMPT,
48
+ negative_prompt="",
49
+ seed=12345,
50
+ )
51
+
52
+ with progress_bar(type="tqdm"):
53
+ save_video(video[0], "video.mp4")
demos/cli.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python
2
+ import json
3
+ import os
4
+ import time
5
+
6
+ import click
7
+ import numpy as np
8
+ import torch
9
+
10
+ from genmo.lib.progress import progress_bar
11
+ from genmo.lib.utils import save_video
12
+ from genmo.mochi_preview.pipelines import (
13
+ DecoderModelFactory,
14
+ DitModelFactory,
15
+ MochiMultiGPUPipeline,
16
+ MochiSingleGPUPipeline,
17
+ T5ModelFactory,
18
+ linear_quadratic_schedule,
19
+ )
20
+
21
+ pipeline = None
22
+ model_dir_path = None
23
+ lora_path = None
24
+ num_gpus = torch.cuda.device_count()
25
+ cpu_offload = False
26
+
27
+
28
+ def configure_model(model_dir_path_, lora_path_, cpu_offload_):
29
+ global model_dir_path, lora_path, cpu_offload
30
+ model_dir_path = model_dir_path_
31
+ lora_path = lora_path_
32
+ cpu_offload = cpu_offload_
33
+
34
+
35
+ def load_model():
36
+ global num_gpus, pipeline, model_dir_path, lora_path
37
+ if pipeline is None:
38
+ MOCHI_DIR = model_dir_path
39
+ print(f"Launching with {num_gpus} GPUs. If you want to force single GPU mode use CUDA_VISIBLE_DEVICES=0.")
40
+ klass = MochiSingleGPUPipeline if num_gpus == 1 else MochiMultiGPUPipeline
41
+ kwargs = dict(
42
+ text_encoder_factory=T5ModelFactory(),
43
+ dit_factory=DitModelFactory(
44
+ model_path=f"{MOCHI_DIR}/dit.safetensors",
45
+ lora_path=lora_path,
46
+ model_dtype="bf16",
47
+ ),
48
+ decoder_factory=DecoderModelFactory(
49
+ model_path=f"{MOCHI_DIR}/decoder.safetensors",
50
+ ),
51
+ )
52
+ if num_gpus > 1:
53
+ assert not lora_path, f"Lora not supported in multi-GPU mode"
54
+ assert not cpu_offload, "CPU offload not supported in multi-GPU mode"
55
+ kwargs["world_size"] = num_gpus
56
+ else:
57
+ kwargs["cpu_offload"] = cpu_offload
58
+ kwargs["decode_type"] = "tiled_spatial"
59
+ kwargs["fast_init"] = not lora_path
60
+ kwargs["strict_load"] = not lora_path
61
+ kwargs["decode_args"] = dict(overlap=8)
62
+ pipeline = klass(**kwargs)
63
+
64
+
65
+ def generate_video(
66
+ prompt,
67
+ negative_prompt,
68
+ width,
69
+ height,
70
+ num_frames,
71
+ seed,
72
+ cfg_scale,
73
+ num_inference_steps,
74
+ ):
75
+ load_model()
76
+
77
+ # sigma_schedule should be a list of floats of length (num_inference_steps + 1),
78
+ # such that sigma_schedule[0] == 1.0 and sigma_schedule[-1] == 0.0 and monotonically decreasing.
79
+ sigma_schedule = linear_quadratic_schedule(num_inference_steps, 0.025)
80
+
81
+ # cfg_schedule should be a list of floats of length num_inference_steps.
82
+ # For simplicity, we just use the same cfg scale at all timesteps,
83
+ # but more optimal schedules may use varying cfg, e.g:
84
+ # [5.0] * (num_inference_steps // 2) + [4.5] * (num_inference_steps // 2)
85
+ cfg_schedule = [cfg_scale] * num_inference_steps
86
+
87
+ args = {
88
+ "height": height,
89
+ "width": width,
90
+ "num_frames": num_frames,
91
+ "sigma_schedule": sigma_schedule,
92
+ "cfg_schedule": cfg_schedule,
93
+ "num_inference_steps": num_inference_steps,
94
+ # We *need* flash attention to batch cfg
95
+ # and it's only worth doing in a high-memory regime (assume multiple GPUs)
96
+ "batch_cfg": False,
97
+ "prompt": prompt,
98
+ "negative_prompt": negative_prompt,
99
+ "seed": seed,
100
+ }
101
+
102
+ with progress_bar(type="tqdm"):
103
+ final_frames = pipeline(**args)
104
+
105
+ final_frames = final_frames[0]
106
+
107
+ assert isinstance(final_frames, np.ndarray)
108
+ assert final_frames.dtype == np.float32
109
+
110
+ os.makedirs("outputs", exist_ok=True)
111
+ output_path = os.path.join("outputs", f"output_{int(time.time())}.mp4")
112
+
113
+ save_video(final_frames, output_path)
114
+ json_path = os.path.splitext(output_path)[0] + ".json"
115
+ json.dump(args, open(json_path, "w"), indent=4)
116
+
117
+ return output_path
118
+
119
+
120
+ from textwrap import dedent
121
+
122
+ DEFAULT_PROMPT = dedent("""
123
+ A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl
124
+ filled with lemons and sprigs of mint against a peach-colored background.
125
+ The hand gently tosses the lemon up and catches it, showcasing its smooth texture.
126
+ A beige string bag sits beside the bowl, adding a rustic touch to the scene.
127
+ Additional lemons, one halved, are scattered around the base of the bowl.
128
+ The even lighting enhances the vibrant colors and creates a fresh,
129
+ inviting atmosphere.
130
+ """)
131
+
132
+
133
+ @click.command()
134
+ @click.option("--prompt", default=DEFAULT_PROMPT, help="Prompt for video generation.")
135
+ @click.option("--negative_prompt", default="", help="Negative prompt for video generation.")
136
+ @click.option("--width", default=848, type=int, help="Width of the video.")
137
+ @click.option("--height", default=480, type=int, help="Height of the video.")
138
+ @click.option("--num_frames", default=163, type=int, help="Number of frames.")
139
+ @click.option("--seed", default=1710977262, type=int, help="Random seed.")
140
+ @click.option("--cfg_scale", default=6.0, type=float, help="CFG Scale.")
141
+ @click.option("--num_steps", default=64, type=int, help="Number of inference steps.")
142
+ @click.option("--model_dir", required=True, help="Path to the model directory.")
143
+ @click.option("--lora_path", required=False, help="Path to the lora file.")
144
+ @click.option("--cpu_offload", is_flag=True, help="Whether to offload model to CPU")
145
+ def generate_cli(
146
+ prompt, negative_prompt, width, height, num_frames, seed, cfg_scale, num_steps, model_dir, lora_path, cpu_offload
147
+ ):
148
+ configure_model(model_dir, lora_path, cpu_offload)
149
+ output = generate_video(
150
+ prompt,
151
+ negative_prompt,
152
+ width,
153
+ height,
154
+ num_frames,
155
+ seed,
156
+ cfg_scale,
157
+ num_steps,
158
+ )
159
+ click.echo(f"Video generated at: {output}")
160
+
161
+
162
+ if __name__ == "__main__":
163
+ generate_cli()
demos/comfyui_nodes.py ADDED
File without changes
demos/fine_tuner/README.md ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Mochi 1 LoRA Fine-tuner
2
+
3
+ ![Mochi being made](../../assets/mochi-factory.webp)
4
+
5
+
6
+ This folder contains tools for fine-tuning the Mochi 1 model. It supports [LoRA](https://arxiv.org/abs/2106.09685) fine-tuning on a single GPU.
7
+
8
+ ## Quick Start (Single GPU)
9
+ This shows you how to prepare your dataset for single GPU.
10
+
11
+ First, setup the inference code and download Mochi 1 weights following [README.md](../../README.md).
12
+ All commands below assume you are in the top-level directory of the Mochi repo.
13
+
14
+ ### 1. Collect your videos and captions
15
+ Collect your videos (supported formats: MP4, MOV) into a folder, e.g. `videos/`. Then, write a detailed description of each of the videos in a txt file with the same name. For example,
16
+ ```
17
+ videos/
18
+ video_1.mp4
19
+ video_1.txt -- One-paragraph description of video_1
20
+ video_2.mp4
21
+ video_2.txt -- One-paragraph description of video_2
22
+ ...
23
+ ```
24
+
25
+ ### 2. Process videos and captions (About 2 minutes)
26
+ Update the paths in the command below to match your dataset. Videos are processed at 30 FPS, so make sure your videos are at least `num_frames / 30` seconds long.
27
+ ```bash
28
+ bash demos/fine_tuner/preprocess.bash -v videos/ -o videos_prepared/ -w weights/ --num_frames 37
29
+ ```
30
+
31
+ ### 3. Fine-tune the model
32
+ Update `./demos/fine_tuner/configs/lora.yaml` to customize the fine-tuning process,
33
+ including prompts to generate at various points of the fine-tuning process and the path to your prepared videos.
34
+
35
+ Launch LoRA fine-tuning on single GPU:
36
+ ```bash
37
+ bash ./demos/fine_tuner/run.bash -c ./demos/fine_tuner/configs/lora.yaml -n 1
38
+ ```
39
+
40
+ Samples will be generated in `finetunes/my_mochi_lora/samples` every 200 steps.
41
+
42
+ ### 4. Use your fine-tuned weights to generate videos!
43
+ Update `--lora_path` to the path of your fine-tuned weights and run:
44
+ ```python
45
+ python3 ./demos/cli.py --model_dir weights/ --lora_path finetunes/my_mochi_lora/model_2000.lora.safetensors --num_frames 37 --cpu_offload --prompt "A delicate porcelain teacup sits on a marble countertop. The teacup suddenly shatters into hundreds of white ceramic shards that scatter through the air. The scene is bright and crisp with dramatic lighting."
46
+ ```
47
+
48
+ You can increase the number of frames to generate a longer video. Finally, share your creations with the community by uploading your LoRA and sample videos to Hugging Face.
49
+
50
+ ## System Requirements
51
+
52
+ **Single GPU:**
53
+ - 1x H100 or A100 (80 GB VRAM is recommended)
54
+ - Less VRAM is required if training with less than 1 second long videos.
55
+
56
+ **Supported video lengths:** Up to 85 frames (~2.8 seconds at 30 FPS)
57
+ - Choose a frame count in increments of 6: 25, 31, 37, ... 79, 85.
58
+ - Training on 37 frames uses 50 GB of VRAM. On 1 H100, each training step takes about 1.67 s/it,
59
+ and you'll start seeing changes to your videos within 200-400 steps. Training for 1,000 steps takes about 30 minutes.
60
+
61
+ Settings tested on 1x H100 SXM:
62
+
63
+ | Frames | Video Length | VRAM | Time/step | num_qkv_checkpoint | num_ff_checkpoint | num_post_attn_checkpoint |
64
+ |--------|--------------|------|-----------|-------------------|-------------------|-------------------------|
65
+ | 37 frames | 1.2 second videos | 50 GB VRAM | 1.67 s/it | 48 | 48† | 48 |
66
+ | 61 frames | 2.0 second videos | 64 GB VRAM | 3.35 s/it | 48 | 48† | 48 |
67
+ | 79 frames | 2.6 second videos | 69-78 GB VRAM | 4.92 s/it | 48 | 48† | 48 |
68
+ | 85 frames | 2.8 second videos | 80 GB VRAM | 5.44 s/it | 48 | 48 | 48 |
69
+
70
+ *† As the VRAM is not fully used, you can lower `num_ff_checkpoint` to speed up training.*
71
+
72
+ ## Technical Details
73
+
74
+ - LoRA fine-tuning updates the query, key, and value projection matrices, as well as the output projection matrix.
75
+ These settings are configurable in `./demos/fine_tuner/configs/lora.yaml`.
76
+ - We welcome contributions and suggestions for improved settings.
77
+
78
+ ## Known Limitations
79
+
80
+ - No support for training on multiple GPUs
81
+ - LoRA inference is restricted to 1-GPU (for now)
82
+
83
+ ## Tips
84
+
85
+ - Be as descriptive as possible in your captions.
86
+ - A learning rate around 1e-4 or 2e-4 seems effective for LoRA fine-tuning.
87
+ - For larger datasets or to customize the model aggressively, increase `num_steps` in in the YAML.
88
+ - To monitor training loss, uncomment the `wandb` section in the YAML and run `wandb login` or set the `WANDB_API_KEY` environment variable.
89
+ - Videos are trimmed to the **first** `num_frames` frames. Make sure your clips contain the content you care about near the beginning.
90
+ You can check the trimmed versions after running `preprocess.bash` to make sure they look good.
91
+ - When capturing HDR videos on an iPhone, convert your .mov files to .mp4 using the Handbrake application. Our preprocessing script won't produce the correct colorspace otherwise, and your fine-tuned videos may look overly bright.
92
+
93
+ ### If you are running out of GPU memory, make sure:
94
+ - `COMPILE_DIT=1` is set in `demos/fine_tuner/run.bash`.
95
+ This enables model compilation, which saves memory and speeds up training!
96
+ - `num_post_attn_checkpoint`, `num_ff_checkpoint`, and `num_qkv_checkpoint` are set to 48 in your YAML.
97
+ You can checkpoint up to 48 layers, saving memory at the cost of slower training.
98
+ - If all else fails, reduce `num_frames` when processing your videos and in your YAML.
99
+ You can fine-tune Mochi on shorter videos, and still generate longer videos at inference time.
demos/fine_tuner/configs/lora.yaml ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ init_checkpoint_path: weights/dit.safetensors
2
+ checkpoint_dir: finetunes/my_mochi_lora
3
+ train_data_dir: videos_prepared
4
+ attention_mode: sdpa
5
+ single_video_mode: false # Useful for debugging whether your model can learn a single video
6
+
7
+ # You only need this if you're using wandb
8
+ wandb:
9
+ # project: mochi_1_lora
10
+ # name: ${checkpoint_dir}
11
+ # group: null
12
+
13
+ optimizer:
14
+ lr: 2e-4
15
+ weight_decay: 0.01
16
+
17
+ model:
18
+ type: lora
19
+ kwargs:
20
+ # Apply LoRA to the QKV projection and the output projection of the attention block.
21
+ qkv_proj_lora_rank: 16
22
+ qkv_proj_lora_alpha: 16
23
+ qkv_proj_lora_dropout: 0.
24
+ out_proj_lora_rank: 16
25
+ out_proj_lora_alpha: 16
26
+ out_proj_lora_dropout: 0.
27
+
28
+ training:
29
+ model_dtype: bf16
30
+ warmup_steps: 200
31
+ num_qkv_checkpoint: 48
32
+ num_ff_checkpoint: 48
33
+ num_post_attn_checkpoint: 48
34
+ num_steps: 2000
35
+ save_interval: 200
36
+ caption_dropout: 0.1
37
+ grad_clip: 0.0
38
+ save_safetensors: true
39
+
40
+ # Used for generating samples during training to monitor progress ...
41
+ sample:
42
+ interval: 200
43
+ output_dir: ${checkpoint_dir}/samples
44
+ decoder_path: weights/decoder.safetensors
45
+ prompts:
46
+ - A pristine snowglobe featuring a winter scene sits peacefully. The globe violently explodes, sending glass, water, and glittering fake snow in all directions. The scene is captured with high-speed photography.
47
+ - A vintage pocket watch ticks quietly on an antique desk. Suddenly, it explodes into gears, springs and metal fragments that scatter through the air. The scene is richly detailed with warm, brass tones.
48
+ - A cello is propped up against a wall, a single spotlight illuminating it. The cello explodes into wooden fragments, sending debris everywhere. The scene is vibrant and colorful.
49
+ - A graphics card sits inside an oven, heatwaves around it. Suddenly, the graphics card explodes into numerous fragments, sending debris everywhere. The scene is darkly lit, high contrast, with a focus on the shattered pieces.
50
+ - A delicate porcelain teacup sits on a marble countertop. The teacup suddenly shatters into hundreds of white ceramic shards that scatter through the air. The scene is bright and crisp with dramatic lighting.
51
+ seed: 12345
52
+ kwargs:
53
+ height: 480
54
+ width: 848
55
+ num_frames: 37
56
+ num_inference_steps: 64
57
+ sigma_schedule_python_code: "linear_quadratic_schedule(64, 0.025)"
58
+ cfg_schedule_python_code: "[6.0] * 64"
demos/fine_tuner/dataset.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import click
4
+ import torch
5
+ from torch.utils.data import DataLoader, Dataset
6
+
7
+
8
+ def load_to_cpu(x):
9
+ return torch.load(x, map_location=torch.device("cpu"), weights_only=True)
10
+
11
+
12
+ class LatentEmbedDataset(Dataset):
13
+ def __init__(self, file_paths, repeat=1):
14
+ self.items = [
15
+ (Path(p).with_suffix(".latent.pt"), Path(p).with_suffix(".embed.pt"))
16
+ for p in file_paths
17
+ if Path(p).with_suffix(".latent.pt").is_file() and Path(p).with_suffix(".embed.pt").is_file()
18
+ ]
19
+ self.items = self.items * repeat
20
+ print(f"Loaded {len(self.items)}/{len(file_paths)} valid file pairs.")
21
+
22
+ def __len__(self):
23
+ return len(self.items)
24
+
25
+ def __getitem__(self, idx):
26
+ latent_path, embed_path = self.items[idx]
27
+ return load_to_cpu(latent_path), load_to_cpu(embed_path)
28
+
29
+
30
+ @click.command()
31
+ @click.argument("directory", type=click.Path(exists=True, file_okay=False))
32
+ def process_videos(directory):
33
+ dir_path = Path(directory)
34
+ mp4_files = [str(f) for f in dir_path.glob("**/*.mp4") if not f.name.endswith(".recon.mp4")]
35
+ assert mp4_files, f"No mp4 files found"
36
+
37
+ dataset = LatentEmbedDataset(mp4_files)
38
+ dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
39
+
40
+ for latents, embeds in dataloader:
41
+ print([(k, v.shape) for k, v in latents.items()])
42
+
43
+
44
+ if __name__ == "__main__":
45
+ process_videos()
demos/fine_tuner/embed_captions.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ from pathlib import Path
3
+
4
+ import click
5
+ import torch
6
+ from tqdm import tqdm
7
+ from transformers import T5Tokenizer
8
+
9
+ from genmo.mochi_preview.pipelines import T5_MODEL, T5ModelFactory, get_conditioning_for_prompts
10
+
11
+
12
+ @click.command()
13
+ @click.argument("captions_dir", type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path))
14
+ @click.option("--device_id", default=0, help="GPU device ID to use")
15
+ @click.option("--overwrite", "-ow", is_flag=True, help="Overwrite existing embeddings")
16
+ def process_captions(captions_dir: Path, device_id: int, overwrite=True) -> None:
17
+ """Process all text files in a directory using T5 encoder.
18
+
19
+ Args:
20
+ captions_dir: Directory containing input text files
21
+ device_id: GPU device ID to use
22
+ """
23
+
24
+ torch.backends.cuda.matmul.allow_tf32 = True
25
+ torch.backends.cudnn.allow_tf32 = True
26
+
27
+ # Get all text file paths
28
+ text_paths = list(captions_dir.glob("**/*.txt"))
29
+ if not text_paths:
30
+ print(f"No text files found in {captions_dir}")
31
+ return
32
+
33
+ # Initialize model and tokenizer
34
+ model_factory = T5ModelFactory()
35
+ device = f"cuda:{device_id}"
36
+ model = model_factory.get_model(local_rank=0, device_id=device_id, world_size=1)
37
+ tokenizer = T5Tokenizer.from_pretrained(T5_MODEL, legacy=False)
38
+
39
+ with tqdm(total=len(text_paths)) as pbar:
40
+ for text_path in text_paths:
41
+ embed_path = text_path.with_suffix(".embed.pt")
42
+ if embed_path.exists() and not overwrite:
43
+ pbar.write(f"Skipping {text_path} - embeddings already exist")
44
+ continue
45
+
46
+ pbar.write(f"Processing {text_path}")
47
+ try:
48
+ with open(text_path) as f:
49
+ text = f.read().strip()
50
+
51
+ with torch.inference_mode():
52
+ conditioning = get_conditioning_for_prompts(tokenizer, model, device, [text])
53
+
54
+ torch.save(conditioning, embed_path)
55
+
56
+ except Exception as e:
57
+ import traceback
58
+
59
+ traceback.print_exc()
60
+ pbar.write(f"Error processing {text_path}: {str(e)}")
61
+
62
+ pbar.update(1)
63
+
64
+
65
+ if __name__ == "__main__":
66
+ process_captions()
demos/fine_tuner/encode_videos.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ import os
3
+ from pathlib import Path
4
+ import traceback
5
+ from typing import Optional
6
+
7
+ import click
8
+ import ray
9
+ import torch
10
+ import torchvision
11
+ from einops import rearrange
12
+
13
+ import genmo.mochi_preview.dit.joint_model.context_parallel as cp
14
+ import genmo.mochi_preview.vae.cp_conv as cp_conv
15
+ from genmo.lib.progress import get_new_progress_bar, progress_bar
16
+ from genmo.lib.utils import Timer, save_video
17
+ from genmo.mochi_preview.pipelines import DecoderModelFactory, EncoderModelFactory
18
+ from genmo.mochi_preview.vae.models import add_fourier_features, decode_latents
19
+
20
+
21
+ class GPUContext:
22
+ def __init__(
23
+ self,
24
+ *,
25
+ encoder_factory: Optional[EncoderModelFactory] = None,
26
+ decoder_factory: Optional[DecoderModelFactory] = None,
27
+ ):
28
+ t = Timer()
29
+ self.device = torch.device(f"cuda")
30
+ if encoder_factory is not None:
31
+ with t("load_encoder"):
32
+ self.encoder = encoder_factory.get_model()
33
+ if decoder_factory is not None:
34
+ with t("load_decoder"):
35
+ self.decoder = decoder_factory.get_model()
36
+ t.print_stats()
37
+
38
+
39
+ def preprocess(ctx: GPUContext, vid_path: Path, shape: str, reconstruct: bool):
40
+ T, H, W = [int(s) for s in shape.split("x")]
41
+ assert (T - 1) % 6 == 0, "Expected T to be 1 mod 6"
42
+ video, _, metadata = torchvision.io.read_video(
43
+ str(vid_path), output_format="THWC", pts_unit="secs")
44
+ fps = metadata["video_fps"]
45
+ video = rearrange(video, "t h w c -> c t h w")
46
+ og_shape = video.shape
47
+ assert video.shape[2] == H, f"Expected {vid_path} to have height {H}, got {video.shape}"
48
+ assert video.shape[3] == W, f"Expected {vid_path} to have width {W}, got {video.shape}"
49
+ assert video.shape[1] >= T, f"Expected {vid_path} to have at least {T} frames, got {video.shape}"
50
+ if video.shape[1] > T:
51
+ video = video[:, :T]
52
+ print(f"Trimmed video from {og_shape[1]} to first {T} frames")
53
+ video = video.unsqueeze(0)
54
+ video = video.float() / 127.5 - 1.0
55
+ video = video.to(ctx.device)
56
+ video = add_fourier_features(video)
57
+
58
+ assert video.ndim == 5
59
+ video = cp.local_shard(video, dim=2) # split along time dimension
60
+
61
+ with torch.inference_mode():
62
+ with torch.autocast("cuda", dtype=torch.bfloat16):
63
+ ldist = ctx.encoder(video)
64
+
65
+ print(f"{og_shape} -> {ldist.mean.shape}")
66
+ torch.save(
67
+ dict(mean=ldist.mean, logvar=ldist.logvar),
68
+ vid_path.with_suffix(".latent.pt"),
69
+ )
70
+
71
+ if reconstruct:
72
+ latents = ldist.sample()
73
+ frames = decode_latents(ctx.decoder, latents)
74
+ frames = frames.cpu().numpy()
75
+ save_video(frames[0], str(vid_path.with_suffix(".recon.mp4")), fps=fps)
76
+
77
+
78
+ @click.command()
79
+ @click.argument("videos_dir", type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path))
80
+ @click.option(
81
+ "--model_dir",
82
+ type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path),
83
+ help="Path to folder containing Mochi's VAE encoder and decoder weights. Download from Hugging Face: https://huggingface.co/genmo/mochi-1-preview/blob/main/encoder.safetensors and https://huggingface.co/genmo/mochi-1-preview/blob/main/decoder.safetensors",
84
+ default="weights/",
85
+ )
86
+ @click.option("--num_gpus", default=1, help="Number of GPUs to split the encoder over")
87
+ @click.option(
88
+ "--recon_interval", default=10, help="Reconstruct one out of every N videos (0 to disable reconstruction)"
89
+ )
90
+ @click.option("--shape", default="163x480x848", help="Shape of the video to encode")
91
+ @click.option("--overwrite", "-ow", is_flag=True, help="Overwrite existing latents")
92
+ def batch_process(
93
+ videos_dir: Path, model_dir: Path, num_gpus: int, recon_interval: int, shape: str, overwrite: bool
94
+ ) -> None:
95
+ """Process all videos in a directory using multiple GPUs.
96
+
97
+ Args:
98
+ videos_dir: Directory containing input videos
99
+ encoder_path: Path to encoder model weights
100
+ decoder_path: Path to decoder model weights
101
+ num_gpus: Number of GPUs to use for parallel processing
102
+ recon_interval: Frequency of video reconstructions (0 to disable)
103
+ """
104
+
105
+ torch.backends.cuda.matmul.allow_tf32 = True
106
+ torch.backends.cudnn.allow_tf32 = True
107
+
108
+ # Get all video paths
109
+ video_paths = list(videos_dir.glob("**/*.mp4"))
110
+ if not video_paths:
111
+ print(f"No MP4 files found in {videos_dir}")
112
+ return
113
+
114
+ preproc = GPUContext(
115
+ encoder_factory=EncoderModelFactory(model_path=os.path.join(model_dir, "encoder.safetensors")),
116
+ decoder_factory=DecoderModelFactory(model_path=os.path.join(model_dir, "decoder.safetensors")),
117
+ )
118
+ with progress_bar(type="ray_tqdm"):
119
+ for idx, video_path in get_new_progress_bar((list(enumerate(sorted(video_paths))))):
120
+ if str(video_path).endswith(".recon.mp4"):
121
+ print(f"Skipping {video_path} b/c it is a reconstruction")
122
+ continue
123
+
124
+ print(f"Processing {video_path}")
125
+ try:
126
+ if video_path.with_suffix(".latent.pt").exists() and not overwrite:
127
+ print(f"Skipping {video_path}")
128
+ continue
129
+
130
+ preprocess(
131
+ ctx=preproc,
132
+ vid_path=video_path,
133
+ shape=shape,
134
+ reconstruct=recon_interval != 0 and idx % recon_interval == 0,
135
+ )
136
+ except Exception as e:
137
+ traceback.print_exc()
138
+ print(f"Error processing {video_path}: {str(e)}")
139
+
140
+
141
+ if __name__ == "__main__":
142
+ batch_process()
demos/fine_tuner/preprocess.bash ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /bin/bash
2
+
3
+ # Enable job control and set process group
4
+ set -eo pipefail
5
+ set -x
6
+
7
+ # Function to display help
8
+ usage() {
9
+ echo "Usage: $0 -v|--videos_dir videos_dir -o|--output_dir output_dir -w|--weights_dir weights_dir -n|--num_frames num_frames"
10
+ echo " -v, --videos_dir Path to the videos directory"
11
+ echo " -o, --output_dir Path to the output directory"
12
+ echo " -w, --weights_dir Path to the weights directory"
13
+ echo " -n, --num_frames Number of frames"
14
+ exit 1
15
+ }
16
+
17
+ # Function to check if the next argument is missing
18
+ check_argument() {
19
+ if [[ -z "$2" || "$2" == -* ]]; then
20
+ echo "Error: Argument for $1 is missing"
21
+ usage
22
+ fi
23
+ }
24
+
25
+ # Parse command-line arguments
26
+ while [[ "$#" -gt 0 ]]; do
27
+ case $1 in
28
+ -v|--videos_dir) check_argument "$1" "$2"; VIDEOS_DIR="$2"; shift ;;
29
+ -o|--output_dir) check_argument "$1" "$2"; OUTPUT_DIR="$2"; shift ;;
30
+ -w|--weights_dir) check_argument "$1" "$2"; WEIGHTS_DIR="$2"; shift ;;
31
+ -n|--num_frames) check_argument "$1" "$2"; NUM_FRAMES="$2"; shift ;;
32
+ -h|--help) usage ;;
33
+ *) echo "Unknown parameter passed: $1"; usage ;;
34
+ esac
35
+ shift
36
+ done
37
+
38
+ # Check if all required arguments are provided
39
+ if [[ -z "$VIDEOS_DIR" || -z "$OUTPUT_DIR" || -z "$WEIGHTS_DIR" || -z "$NUM_FRAMES" ]]; then
40
+ echo "Error: All arguments are required."
41
+ usage
42
+ fi
43
+
44
+ # Get the directory where this script is located
45
+ SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
46
+ echo "Using script directory: ${SCRIPT_DIR}"
47
+
48
+ ##### Step 1: Trim and resize videos
49
+ echo -e "\n\e[1;35m🎬 **Step 1: Trim and resize videos** \e[0m"
50
+ # Calculate duration to trim videos
51
+ DURATION=$(printf "%.1f" "$(echo "($NUM_FRAMES / 30) + 0.09" | bc -l)")
52
+ echo "Trimming videos to duration: ${DURATION} seconds"
53
+ python3 ${SCRIPT_DIR}/trim_and_crop_videos.py ${VIDEOS_DIR} ${OUTPUT_DIR} -d ${DURATION}
54
+
55
+ ##### Step 2: Run the VAE encoder on each video.
56
+ echo -e "\n\e[1;35m🎥 **Step 2: Run the VAE encoder on each video** \e[0m"
57
+ python3 ${SCRIPT_DIR}/encode_videos.py ${OUTPUT_DIR} \
58
+ --model_dir ${WEIGHTS_DIR} --num_gpus 1 --shape "${NUM_FRAMES}x480x848" --overwrite
59
+
60
+ ##### Step 3: Compute T5 embeddings
61
+ echo -e "\n\e[1;35m🧠 **Step 3: Compute T5 embeddings** \e[0m"
62
+ python3 ${SCRIPT_DIR}/embed_captions.py --overwrite ${OUTPUT_DIR}
63
+
64
+ echo -e "\n\e[1;32m✓ Done!\e[0m"
demos/fine_tuner/run.bash ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /bin/bash
2
+
3
+ # Enable job control and set process group
4
+ set -m
5
+ trap 'kill $(jobs -p)' EXIT INT TERM
6
+
7
+ # Get the directory where this script is located
8
+ SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
9
+ DEFAULT_CONFIG="${SCRIPT_DIR}/configs/finetune.yaml"
10
+
11
+ # Parse command line arguments
12
+ usage() {
13
+ echo "Usage: $0 [-c|--config <config_path>] [-n|--num-gpus <num_gpus>]"
14
+ echo " -c, --config Path to config file (default: ${DEFAULT_CONFIG})"
15
+ echo " -n, --num-gpus Number of GPUs to use (default: 8)"
16
+ exit 1
17
+ }
18
+
19
+ # Default values
20
+ CONFIG_PATH="${DEFAULT_CONFIG}"
21
+ NUM_GPUS=8
22
+
23
+ # Parse arguments
24
+ while [[ $# -gt 0 ]]; do
25
+ case $1 in
26
+ -c|--config)
27
+ CONFIG_PATH="$2"
28
+ shift 2
29
+ ;;
30
+ -n|--num-gpus)
31
+ NUM_GPUS="$2"
32
+ shift 2
33
+ ;;
34
+ -h|--help)
35
+ usage
36
+ ;;
37
+ *)
38
+ echo "Unknown option: $1"
39
+ usage
40
+ ;;
41
+ esac
42
+ done
43
+
44
+ # Validate config file exists
45
+ if [ ! -f "${CONFIG_PATH}" ]; then
46
+ echo "Config file not found at ${CONFIG_PATH}"
47
+ exit 1
48
+ fi
49
+
50
+ # Validate num_gpus is a positive integer
51
+ if ! [[ "$NUM_GPUS" =~ ^[1-9][0-9]*$ ]]; then
52
+ echo "Number of GPUs must be a positive integer"
53
+ exit 1
54
+ fi
55
+
56
+ # Set distributed training environment variables
57
+ export MASTER_PORT=29500
58
+ export MASTER_ADDR="localhost"
59
+ export WORLD_SIZE=$NUM_GPUS
60
+ export TF_CPP_MIN_LOG_LEVEL=3
61
+ export COMPILE_DIT=1
62
+
63
+ # Set IS_DISTRIBUTED based on NUM_GPUS
64
+ if [ "$NUM_GPUS" -gt 1 ]; then
65
+ export IS_DISTRIBUTED=true
66
+ fi
67
+
68
+ # Load .env file (if it exists)
69
+ if [ -f ".env" ]; then
70
+ export $(grep -v '^#' .env | xargs)
71
+ fi
72
+
73
+ echo "Starting training with ${NUM_GPUS} GPU(s), mode: ${IS_DISTRIBUTED:+distributed}${IS_DISTRIBUTED:-single_gpu}"
74
+ echo "Using config: ${CONFIG_PATH}"
75
+
76
+ # Launch processes
77
+ if [ "$NUM_GPUS" -gt 1 ]; then
78
+ for RANK in $(seq 0 $((NUM_GPUS-1))); do
79
+ env RANK=$RANK CUDA_VISIBLE_DEVICES=$RANK python "${SCRIPT_DIR}/train.py" --config-path "${CONFIG_PATH}" &
80
+ done
81
+ else
82
+ python "${SCRIPT_DIR}/train.py" --config-path "${CONFIG_PATH}" &
83
+ fi
84
+
85
+ # Wait for all background processes to complete
86
+ wait
87
+
88
+ # Check if any process failed
89
+ if [ $? -ne 0 ]; then
90
+ echo "One or more training processes failed"
91
+ exit 1
92
+ fi
demos/fine_tuner/train.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import multiprocessing as mp
3
+ import os
4
+ import random
5
+ import re
6
+ import sys
7
+ import time
8
+ from contextlib import contextmanager
9
+ from glob import glob
10
+ from pathlib import Path
11
+ from typing import Any, Dict, Tuple, cast
12
+
13
+ import click
14
+ import numpy as np
15
+ from omegaconf import DictConfig, ListConfig, OmegaConf
16
+ from safetensors.torch import save_file
17
+ import torch
18
+ from torch import Tensor
19
+ from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict
20
+ import torch.nn.functional as F
21
+ from tqdm import tqdm
22
+
23
+ torch._dynamo.config.cache_size_limit = 32
24
+ torch.backends.cuda.matmul.allow_tf32 = True
25
+ torch.backends.cudnn.allow_tf32 = True
26
+ torch.use_deterministic_algorithms(False)
27
+
28
+ import genmo.mochi_preview.dit.joint_model.lora as lora
29
+ from genmo.lib.progress import progress_bar
30
+ from genmo.lib.utils import Timer, save_video
31
+ from genmo.mochi_preview.pipelines import (
32
+ DecoderModelFactory,
33
+ DitModelFactory,
34
+ ModelFactory,
35
+ T5ModelFactory,
36
+ cast_dit,
37
+ compute_packed_indices,
38
+ get_conditioning,
39
+ linear_quadratic_schedule, # used in eval'd Python code in lora.yaml
40
+ load_to_cpu,
41
+ move_to_device,
42
+ sample_model,
43
+ t5_tokenizer,
44
+ )
45
+ from genmo.mochi_preview.vae.latent_dist import LatentDistribution
46
+ from genmo.mochi_preview.vae.models import decode_latents_tiled_spatial
47
+
48
+ sys.path.append("..")
49
+
50
+ from dataset import LatentEmbedDataset
51
+
52
+
53
+ class MochiTorchRunEvalPipeline:
54
+ def __init__(
55
+ self,
56
+ *,
57
+ device_id,
58
+ dit,
59
+ text_encoder_factory: ModelFactory,
60
+ decoder_factory: ModelFactory,
61
+ ):
62
+ self.device = torch.device(f"cuda:{device_id}")
63
+ self.tokenizer = t5_tokenizer()
64
+ t = Timer()
65
+ self.dit = dit
66
+ with t("load_text_encoder"):
67
+ self.text_encoder = text_encoder_factory.get_model(
68
+ local_rank=0,
69
+ world_size=1,
70
+ device_id="cpu",
71
+ )
72
+ with t("load_vae"):
73
+ self.decoder = decoder_factory.get_model(local_rank=0, device_id="cpu", world_size=1)
74
+ t.print_stats() # type: ignore
75
+
76
+ def __call__(self, prompt, save_path, **kwargs):
77
+ with progress_bar(type="tqdm", enabled=True), torch.inference_mode():
78
+ # Encode prompt with T5 XXL.
79
+ with move_to_device(self.text_encoder, self.device, enabled=True):
80
+ conditioning = get_conditioning(
81
+ self.tokenizer,
82
+ self.text_encoder,
83
+ self.device,
84
+ batch_inputs=False,
85
+ prompt=prompt,
86
+ negative_prompt="",
87
+ )
88
+
89
+ # Sample video latents from Mochi.
90
+ with move_to_device(self.dit, self.device, enabled=True):
91
+ latents = sample_model(self.device, self.dit, conditioning, **kwargs)
92
+
93
+ # Decode video latents to frames.
94
+ with move_to_device(self.decoder, self.device, enabled=True):
95
+ frames = decode_latents_tiled_spatial(
96
+ self.decoder, latents, num_tiles_w=2, num_tiles_h=2, overlap=8)
97
+ frames = frames.cpu().numpy() # b t h w c
98
+ assert isinstance(frames, np.ndarray)
99
+
100
+ save_video(frames[0], save_path)
101
+
102
+
103
+ def map_to_device(x, device: torch.device):
104
+ if isinstance(x, dict):
105
+ return {k: map_to_device(v, device) for k, v in x.items()}
106
+ elif isinstance(x, list):
107
+ return [map_to_device(y, device) for y in x]
108
+ elif isinstance(x, tuple):
109
+ return tuple(map_to_device(y, device) for y in x)
110
+ elif isinstance(x, torch.Tensor):
111
+ return x.to(device, non_blocking=True)
112
+ else:
113
+ return x
114
+
115
+
116
+ EPOCH_IDX = 0
117
+
118
+
119
+ def infinite_dl(dl):
120
+ global EPOCH_IDX
121
+ while True:
122
+ EPOCH_IDX += 1
123
+ for batch in dl:
124
+ yield batch
125
+
126
+
127
+ @contextmanager
128
+ def timer(description="Task", enabled=True):
129
+ if enabled:
130
+ start = time.perf_counter()
131
+ try:
132
+ yield
133
+ finally:
134
+ if enabled:
135
+ elapsed = time.perf_counter() - start # type: ignore
136
+ print(f"{description} took {elapsed:.4f} seconds")
137
+
138
+
139
+ def get_cosine_annealing_lr_scheduler(
140
+ optimizer: torch.optim.Optimizer,
141
+ warmup_steps: int,
142
+ total_steps: int,
143
+ ):
144
+ def lr_lambda(step):
145
+ if step < warmup_steps:
146
+ return float(step) / float(max(1, warmup_steps))
147
+ else:
148
+ return 0.5 * (1 + np.cos(np.pi * (step - warmup_steps) / (total_steps - warmup_steps)))
149
+
150
+ return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
151
+
152
+
153
+ @click.command()
154
+ @click.option("--config-path", type=click.Path(exists=True), required=True, help="Path to YAML config file")
155
+ def main(config_path):
156
+ mp.set_start_method("spawn", force=True)
157
+ cfg = cast(DictConfig, OmegaConf.load(config_path))
158
+
159
+ device_id = 0
160
+ device_str = f"cuda:0"
161
+ device = torch.device(device_str)
162
+
163
+ # Verify checkpoint path exists
164
+ checkpoint_path = Path(cfg.init_checkpoint_path)
165
+ assert checkpoint_path.exists(), f"Checkpoint file not found: {checkpoint_path}"
166
+
167
+ # Create checkpoint directory if it doesn't exist
168
+ checkpoint_dir = Path(cfg.checkpoint_dir)
169
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
170
+
171
+ # Get step number from checkpoint filename
172
+ pattern = r"model_(\d+)\.(lora|checkpoint)\.(safetensors|pt)"
173
+ match = re.search(pattern, str(checkpoint_path))
174
+ if match:
175
+ start_step_num = int(match.group(1))
176
+ opt_path = str(checkpoint_path).replace("model_", "optimizer_")
177
+ else:
178
+ start_step_num = 0
179
+ opt_path = ""
180
+
181
+ print(
182
+ f"model={checkpoint_path}, optimizer={opt_path}, start_step_num={start_step_num}"
183
+ )
184
+
185
+ wandb_run = None
186
+ sample_prompts = cfg.sample.prompts
187
+
188
+ train_vids = list(sorted(glob(f"{cfg.train_data_dir}/*.mp4")))
189
+ train_vids = [v for v in train_vids if not v.endswith(".recon.mp4")]
190
+ print(f"Found {len(train_vids)} training videos in {cfg.train_data_dir}")
191
+ assert len(train_vids) > 0, f"No training data found in {cfg.train_data_dir}"
192
+ if cfg.single_video_mode:
193
+ train_vids = train_vids[:1]
194
+ sample_prompts = [Path(train_vids[0]).with_suffix(".txt").read_text()]
195
+ print(f"Training on video: {train_vids[0]}")
196
+
197
+ train_dataset = LatentEmbedDataset(
198
+ train_vids,
199
+ repeat=1_000 if cfg.single_video_mode else 1,
200
+ )
201
+ train_dl = torch.utils.data.DataLoader(
202
+ train_dataset,
203
+ batch_size=None,
204
+ num_workers=4,
205
+ shuffle=True,
206
+ pin_memory=True,
207
+ )
208
+ train_dl_iter = infinite_dl(train_dl)
209
+
210
+ if cfg.get("wandb"):
211
+ import wandb
212
+
213
+ wandb_run = wandb.init(
214
+ project=cfg.wandb.project,
215
+ name=f"{cfg.wandb.name}-{int(time.time())}",
216
+ config=OmegaConf.to_container(cfg), # type: ignore
217
+ )
218
+ print(f"🚀 Weights & Biases run URL: {wandb_run.get_url()}")
219
+
220
+ print("Loading model")
221
+ patch_model_fns = []
222
+ model_kwargs = {}
223
+ is_lora = cfg.model.type == "lora"
224
+ print(f"Training type: {'LoRA' if is_lora else 'Full'}")
225
+ if is_lora:
226
+ def mark_lora_params(m):
227
+ lora.mark_only_lora_as_trainable(m, bias="none")
228
+ return m
229
+
230
+ patch_model_fns.append(mark_lora_params)
231
+ model_kwargs = dict(**cfg.model.kwargs)
232
+ # Replace ListConfig with list to allow serialization to JSON.
233
+ for k, v in model_kwargs.items():
234
+ if isinstance(v, ListConfig):
235
+ model_kwargs[k] = list(v)
236
+
237
+ if cfg.training.get("model_dtype"):
238
+ assert cfg.training.model_dtype == "bf16", f"Only bf16 is supported"
239
+ patch_model_fns.append(lambda m: cast_dit(m, torch.bfloat16))
240
+
241
+ model = (
242
+ DitModelFactory(
243
+ model_path=str(checkpoint_path),
244
+ model_dtype="bf16",
245
+ attention_mode=cfg.attention_mode
246
+ ).get_model(
247
+ local_rank=0,
248
+ device_id=device_id,
249
+ model_kwargs=model_kwargs,
250
+ patch_model_fns=patch_model_fns,
251
+ world_size=1,
252
+ strict_load=not is_lora,
253
+ fast_init=not is_lora, # fast_init not supported for LoRA (please someone fix this !!!)
254
+ )
255
+ .train() # calling train() makes sure LoRA weights are not merged
256
+ )
257
+
258
+ optimizer = torch.optim.AdamW(model.parameters(), **cfg.optimizer)
259
+ if os.path.exists(opt_path):
260
+ print("Loading optimizer")
261
+ optimizer.load_state_dict(load_to_cpu(opt_path))
262
+
263
+ scheduler = get_cosine_annealing_lr_scheduler(
264
+ optimizer,
265
+ warmup_steps=cfg.training.warmup_steps,
266
+ total_steps=cfg.training.num_steps
267
+ )
268
+
269
+ print("Loading eval pipeline ...")
270
+ eval_pipeline = MochiTorchRunEvalPipeline(
271
+ device_id=device_id,
272
+ dit=model,
273
+ text_encoder_factory=T5ModelFactory(),
274
+ decoder_factory=DecoderModelFactory(model_path=cfg.sample.decoder_path),
275
+ )
276
+
277
+ def get_batch() -> Tuple[Dict[str, Any], Tensor, Tensor, Tensor]:
278
+ nonlocal train_dl_iter
279
+ batch = next(train_dl_iter) # type: ignore
280
+ latent, embed = cast(Tuple[Dict[str, Any], Dict[str, Any]], batch)
281
+ assert len(embed["y_feat"]) == 1 and len(embed["y_mask"]) == 1, f"Only batch size 1 is supported"
282
+
283
+ ldist = LatentDistribution(latent["mean"], latent["logvar"])
284
+ z = ldist.sample()
285
+ assert torch.isfinite(z).all()
286
+ assert z.shape[0] == 1, f"Only batch size 1 is supported"
287
+
288
+ eps = torch.randn_like(z)
289
+ sigma = torch.rand(z.shape[:1], device="cpu", dtype=torch.float32)
290
+
291
+ if random.random() < cfg.training.caption_dropout:
292
+ embed["y_mask"][0].zero_()
293
+ embed["y_feat"][0].zero_()
294
+ return embed, z, eps, sigma
295
+
296
+ pbar = tqdm(
297
+ range(start_step_num, cfg.training.num_steps),
298
+ total=cfg.training.num_steps,
299
+ initial=start_step_num,
300
+ )
301
+ for step in pbar:
302
+ if cfg.sample.interval and step % cfg.sample.interval == 0 and step > 0:
303
+ sample_dir = Path(cfg.sample.output_dir)
304
+ sample_dir.mkdir(exist_ok=True)
305
+ model.eval()
306
+ for eval_idx, prompt in enumerate(sample_prompts):
307
+ save_path = sample_dir / f"{eval_idx}_{step}.mp4"
308
+ if save_path.exists():
309
+ print(f"Skipping {save_path} as it already exists")
310
+ continue
311
+
312
+ sample_kwargs = {
313
+ k.removesuffix("_python_code"): (eval(v) if k.endswith("_python_code") else v)
314
+ for k, v in cfg.sample.kwargs.items()
315
+ }
316
+ eval_pipeline(
317
+ prompt=prompt,
318
+ save_path=str(save_path),
319
+ seed=cfg.sample.seed + eval_idx,
320
+ **sample_kwargs,
321
+ )
322
+ Path(sample_dir / f"{eval_idx}_{step}.txt").write_text(prompt)
323
+ model.train()
324
+
325
+ if cfg.training.save_interval and step > 0 and step % cfg.training.save_interval == 0:
326
+ with timer("get_state_dict"):
327
+ if is_lora:
328
+ model_sd = lora.lora_state_dict(model, bias="none")
329
+ else:
330
+ # NOTE: Not saving optimizer state dict to save space.
331
+ model_sd, _optimizer_sd = get_state_dict(
332
+ model, [], options=StateDictOptions(cpu_offload=True, full_state_dict=True)
333
+ )
334
+
335
+ checkpoint_filename = f"model_{step}.{'lora' if is_lora else 'checkpoint'}.pt"
336
+ save_path = checkpoint_dir / checkpoint_filename
337
+ if cfg.training.get("save_safetensors", True):
338
+ save_path = save_path.with_suffix(".safetensors")
339
+ save_file(
340
+ model_sd, save_path,
341
+ # `safetensors` only supports string-to-string metadata,
342
+ # so we serialize the kwargs to a JSON string.
343
+ metadata=dict(kwargs=json.dumps(model_kwargs)),
344
+ )
345
+ else:
346
+ torch.save(model_sd, save_path)
347
+
348
+ with torch.no_grad(), timer("load_batch", enabled=False):
349
+ batch = get_batch()
350
+ embed, z, eps, sigma = map_to_device(batch, device)
351
+ embed = cast(Dict[str, Any], embed)
352
+
353
+ num_latent_toks = np.prod(z.shape[-3:])
354
+ indices = compute_packed_indices(device, cast(Tensor, embed["y_mask"][0]), int(num_latent_toks))
355
+
356
+ sigma_bcthw = sigma[:, None, None, None, None] # [B, 1, 1, 1, 1]
357
+ z_sigma = (1 - sigma_bcthw) * z + sigma_bcthw * eps
358
+ ut = z - eps
359
+
360
+ with torch.autocast("cuda", dtype=torch.bfloat16):
361
+ preds = model(
362
+ x=z_sigma,
363
+ sigma=sigma,
364
+ packed_indices=indices,
365
+ **embed,
366
+ num_ff_checkpoint=cfg.training.num_ff_checkpoint,
367
+ num_qkv_checkpoint=cfg.training.num_qkv_checkpoint,
368
+ )
369
+ assert preds.shape == z.shape
370
+
371
+ loss = F.mse_loss(preds.float(), ut.float())
372
+ loss.backward()
373
+
374
+ log_kwargs = {
375
+ "train/loss": loss.item(),
376
+ "train/epoch": EPOCH_IDX,
377
+ "train/lr": scheduler.get_last_lr()[0],
378
+ }
379
+
380
+ if cfg.training.get("grad_clip"):
381
+ assert not is_lora, "Gradient clipping not supported for LoRA"
382
+ gnorm_before_clip = torch.nn.utils.clip_grad_norm_(
383
+ model.parameters(), max_norm=cfg.training.grad_clip)
384
+ log_kwargs["train/gnorm"] = gnorm_before_clip.item()
385
+ pbar.set_postfix(**log_kwargs)
386
+
387
+ if wandb_run:
388
+ wandb_run.log(log_kwargs, step=step)
389
+
390
+ optimizer.step()
391
+ scheduler.step()
392
+ optimizer.zero_grad()
393
+
394
+
395
+ if __name__ == "__main__":
396
+ main()
demos/fine_tuner/trim_and_crop_videos.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ from pathlib import Path
3
+ import shutil
4
+
5
+ import click
6
+ from moviepy.editor import VideoFileClip
7
+ from tqdm import tqdm
8
+
9
+
10
+ @click.command()
11
+ @click.argument("folder", type=click.Path(exists=True, dir_okay=True))
12
+ @click.argument("output_folder", type=click.Path(dir_okay=True))
13
+ @click.option("--duration", "-d", type=float, default=5.4, help="Duration in seconds")
14
+ @click.option("--resolution", "-r", type=str, default="848x480", help="Video resolution")
15
+ def truncate_videos(folder, output_folder, duration, resolution):
16
+ """Truncate all MP4 and MOV files in FOLDER to specified duration and resolution"""
17
+ input_path = Path(folder)
18
+ output_path = Path(output_folder)
19
+ output_path.mkdir(parents=True, exist_ok=True)
20
+
21
+ # Parse target resolution
22
+ target_width, target_height = map(int, resolution.split("x"))
23
+
24
+ # Find all MP4 and MOV files
25
+ video_files = (
26
+ list(input_path.rglob("*.mp4"))
27
+ + list(input_path.rglob("*.MOV"))
28
+ + list(input_path.rglob("*.mov"))
29
+ + list(input_path.rglob("*.MP4"))
30
+ )
31
+
32
+ for file_path in tqdm(video_files):
33
+ try:
34
+ relative_path = file_path.relative_to(input_path)
35
+ output_file = output_path / relative_path.with_suffix(".mp4")
36
+ output_file.parent.mkdir(parents=True, exist_ok=True)
37
+
38
+ click.echo(f"Processing: {file_path}")
39
+ video = VideoFileClip(str(file_path))
40
+
41
+ # Skip if video is too short
42
+ if video.duration < duration:
43
+ click.echo(f"Skipping {file_path} as it is too short")
44
+ continue
45
+
46
+ # Skip if target resolution is larger than input
47
+ if target_width > video.w or target_height > video.h:
48
+ click.echo(
49
+ f"Skipping {file_path} as target resolution {resolution} is larger than input {video.w}x{video.h}"
50
+ )
51
+ continue
52
+
53
+ # First truncate duration
54
+ truncated = video.subclip(0, duration)
55
+
56
+ # Calculate crop dimensions to maintain aspect ratio
57
+ target_ratio = target_width / target_height
58
+ current_ratio = truncated.w / truncated.h
59
+
60
+ if current_ratio > target_ratio:
61
+ # Video is wider than target ratio - crop width
62
+ new_width = int(truncated.h * target_ratio)
63
+ x1 = (truncated.w - new_width) // 2
64
+ final = truncated.crop(x1=x1, width=new_width).resize((target_width, target_height))
65
+ else:
66
+ # Video is taller than target ratio - crop height
67
+ new_height = int(truncated.w / target_ratio)
68
+ y1 = (truncated.h - new_height) // 2
69
+ final = truncated.crop(y1=y1, height=new_height).resize((target_width, target_height))
70
+
71
+ # Set output parameters for consistent MP4 encoding
72
+ output_params = {
73
+ "codec": "libx264",
74
+ "audio": False, # Disable audio
75
+ "preset": "medium", # Balance between speed and quality
76
+ "bitrate": "5000k", # Adjust as needed
77
+ }
78
+
79
+ # Set FPS to 30
80
+ final = final.set_fps(30)
81
+
82
+ # Check for a corresponding .txt file
83
+ txt_file_path = file_path.with_suffix('.txt')
84
+ if txt_file_path.exists():
85
+ output_txt_file = output_path / relative_path.with_suffix('.txt')
86
+ output_txt_file.parent.mkdir(parents=True, exist_ok=True)
87
+ shutil.copy(txt_file_path, output_txt_file)
88
+ click.echo(f"Copied {txt_file_path} to {output_txt_file}")
89
+ else:
90
+ # Print warning in bold yellow with a warning emoji
91
+ click.echo(f"\033[1;33m⚠️ Warning: No caption found for {file_path}, using an empty caption. This may hurt fine-tuning quality.\033[0m")
92
+ output_txt_file = output_path / relative_path.with_suffix('.txt')
93
+ output_txt_file.parent.mkdir(parents=True, exist_ok=True)
94
+ output_txt_file.touch()
95
+
96
+ # Write the output file
97
+ final.write_videofile(str(output_file), **output_params)
98
+
99
+ # Clean up
100
+ video.close()
101
+ truncated.close()
102
+ final.close()
103
+
104
+ except Exception as e:
105
+ click.echo(f"\033[1;31m Error processing {file_path}: {str(e)}\033[0m", err=True)
106
+ raise
107
+
108
+
109
+ if __name__ == "__main__":
110
+ truncate_videos()
demos/gradio_ui.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python
2
+
3
+
4
+ import sys
5
+
6
+ import click
7
+ import gradio as gr
8
+
9
+ sys.path.append("..")
10
+ from cli import configure_model, generate_video
11
+
12
+ with gr.Blocks() as demo:
13
+ gr.Markdown("Video Generator")
14
+ with gr.Row():
15
+ prompt = gr.Textbox(
16
+ label="Prompt",
17
+ value="A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl filled with lemons and sprigs of mint against a peach-colored background. The hand gently tosses the lemon up and catches it, showcasing its smooth texture. A beige string bag sits beside the bowl, adding a rustic touch to the scene. Additional lemons, one halved, are scattered around the base of the bowl. The even lighting enhances the vibrant colors and creates a fresh, inviting atmosphere.",
18
+ )
19
+ negative_prompt = gr.Textbox(label="Negative Prompt", value="")
20
+ seed = gr.Number(label="Seed", value=1710977262, precision=0)
21
+ with gr.Row():
22
+ width = gr.Number(label="Width", value=848, precision=0)
23
+ height = gr.Number(label="Height", value=480, precision=0)
24
+ num_frames = gr.Number(label="Number of Frames", value=163, precision=0)
25
+ with gr.Row():
26
+ cfg_scale = gr.Number(label="CFG Scale", value=6.0)
27
+ num_inference_steps = gr.Number(label="Number of Inference Steps", value=100, precision=0)
28
+ btn = gr.Button("Generate Video")
29
+ output = gr.Video()
30
+
31
+ btn.click(
32
+ generate_video,
33
+ inputs=[
34
+ prompt,
35
+ negative_prompt,
36
+ width,
37
+ height,
38
+ num_frames,
39
+ seed,
40
+ cfg_scale,
41
+ num_inference_steps,
42
+ ],
43
+ outputs=output,
44
+ )
45
+
46
+
47
+ @click.command()
48
+ @click.option("--model_dir", required=True, help="Path to the model directory.")
49
+ @click.option("--lora_path", required=False, help="Path to the lora file.")
50
+ @click.option("--cpu_offload", is_flag=True, help="Whether to offload model to CPU")
51
+ def launch(model_dir, lora_path, cpu_offload):
52
+ configure_model(model_dir, lora_path, cpu_offload)
53
+ demo.launch()
54
+
55
+
56
+ if __name__ == "__main__":
57
+ launch()
demos/test_encoder_decoder.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import click
4
+ import torch
5
+ import torchvision
6
+ from einops import rearrange
7
+ from safetensors.torch import load_file
8
+
9
+ from genmo.lib.utils import save_video
10
+ from genmo.mochi_preview.pipelines import DecoderModelFactory, decode_latents_tiled_spatial
11
+ from genmo.mochi_preview.vae.models import Encoder, add_fourier_features
12
+
13
+
14
+ @click.command()
15
+ @click.argument("mochi_dir", type=str)
16
+ @click.argument("video_path", type=click.Path(exists=True))
17
+ def reconstruct(mochi_dir, video_path):
18
+ torch.backends.cuda.matmul.allow_tf32 = True
19
+ torch.backends.cudnn.allow_tf32 = True
20
+
21
+ decoder_factory = DecoderModelFactory(
22
+ model_path=f"{mochi_dir}/decoder.safetensors",
23
+ )
24
+ decoder = decoder_factory.get_model(world_size=1, device_id=0, local_rank=0)
25
+
26
+ config = dict(
27
+ prune_bottlenecks=[False, False, False, False, False],
28
+ has_attentions=[False, True, True, True, True],
29
+ affine=True,
30
+ bias=True,
31
+ input_is_conv_1x1=True,
32
+ padding_mode="replicate",
33
+ )
34
+
35
+ # Create VAE encoder
36
+ encoder = Encoder(
37
+ in_channels=15,
38
+ base_channels=64,
39
+ channel_multipliers=[1, 2, 4, 6],
40
+ num_res_blocks=[3, 3, 4, 6, 3],
41
+ latent_dim=12,
42
+ temporal_reductions=[1, 2, 3],
43
+ spatial_reductions=[2, 2, 2],
44
+ **config,
45
+ )
46
+ device = torch.device("cuda:0")
47
+ encoder = encoder.to(device, memory_format=torch.channels_last_3d)
48
+ encoder.load_state_dict(load_file(f"{mochi_dir}/encoder.safetensors"))
49
+ encoder.eval()
50
+
51
+ video, _, metadata = torchvision.io.read_video(video_path, output_format="THWC")
52
+ fps = metadata["video_fps"]
53
+ video = rearrange(video, "t h w c -> c t h w")
54
+ video = video.unsqueeze(0)
55
+ assert video.dtype == torch.uint8
56
+ # Convert to float in [-1, 1] range.
57
+ video = video.float() / 127.5 - 1.0
58
+ video = video.to(device)
59
+ video = add_fourier_features(video)
60
+ torch.cuda.synchronize()
61
+
62
+ # Encode video to latent
63
+ with torch.inference_mode():
64
+ with torch.autocast("cuda", dtype=torch.bfloat16):
65
+ t0 = time.time()
66
+ ldist = encoder(video)
67
+ torch.cuda.synchronize()
68
+ print(f"Time to encode: {time.time() - t0:.2f}s")
69
+ t0 = time.time()
70
+ frames = decode_latents_tiled_spatial(decoder, ldist.sample(), num_tiles_w=2, num_tiles_h=2)
71
+ torch.cuda.synchronize()
72
+ print(f"Time to decode: {time.time() - t0:.2f}s")
73
+ t0 = time.time()
74
+ save_video(frames.cpu().numpy()[0], f"{video_path}.recon.mp4", fps=fps)
75
+ print(f"Time to save: {time.time() - t0:.2f}s")
76
+
77
+
78
+ if __name__ == "__main__":
79
+ reconstruct()
encoder.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3a8827a66b58a479d97420a9bf77e59078d88f538298469d8db28c37bd556ae
3
+ size 388912864
model_index.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "MochiPipeline",
3
+ "_diffusers_version": "0.32.0.dev0",
4
+ "scheduler": [
5
+ "diffusers",
6
+ "FlowMatchEulerDiscreteScheduler"
7
+ ],
8
+ "text_encoder": [
9
+ "transformers",
10
+ "T5EncoderModel"
11
+ ],
12
+ "tokenizer": [
13
+ "transformers",
14
+ "T5Tokenizer"
15
+ ],
16
+ "transformer": [
17
+ "diffusers",
18
+ "MochiTransformer3DModel"
19
+ ],
20
+ "vae": [
21
+ "diffusers",
22
+ "AutoencoderKLMochi"
23
+ ]
24
+ }
pusa_v0_dit.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b59c4675886494861ec02ea3a11706815efa0e8909dc868109b0b25e79bfcb0
3
+ size 40110801256
pyproject.toml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "genmo"
3
+ version = "0.1.0"
4
+ description = "Genmo models"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ dependencies = [
8
+ "addict>=2.4.0",
9
+ "av==13.1.0",
10
+ "click>=8.1.7",
11
+ "einops>=0.8.0",
12
+ "gradio>=3.36.1",
13
+ "moviepy==1.0.3",
14
+ "omegaconf>=2.3.0",
15
+ "pillow==9.5.0",
16
+ "pyyaml>=6.0.2",
17
+ "ray>=2.37.0",
18
+ "sentencepiece>=0.2.0",
19
+ "setuptools>=75.2.0",
20
+ "torch>=2.4.1",
21
+ "torchvision>=0.19.1",
22
+ "transformers>=4.45.2",
23
+ ]
24
+
25
+ [project.optional-dependencies]
26
+ flash = [
27
+ "flash-attn>=2.6.3"
28
+ ]
29
+
30
+ torchvision = [
31
+ "torchvision>=0.15.0",
32
+ "pyav>=13.1.0"
33
+ ]
34
+
35
+ [tool.ruff]
36
+ # Allow lines to be as long as 120.
37
+ line-length = 120
pyrightconfig.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "include": ["src/genmo/mochi_preview/pipelines.py"]
3
+ }
4
+
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ addict>=2.4.0
2
+ av==13.1.0
3
+ click>=8.1.7
4
+ einops>=0.8.0
5
+ gradio>=3.36.1
6
+ moviepy==1.0.3
7
+ omegaconf>=2.3.0
8
+ pillow==9.5.0
9
+ pyyaml>=6.0.2
10
+ ray>=2.37.0
11
+ sentencepiece>=0.2.0
12
+ setuptools>=75.2.0
13
+ torch>=2.4.1
14
+ transformers>=4.45.2
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.32.0.dev0",
4
+ "base_image_seq_len": 256,
5
+ "base_shift": 0.5,
6
+ "invert_sigmas": true,
7
+ "max_image_seq_len": 4096,
8
+ "max_shift": 1.15,
9
+ "num_train_timesteps": 1000,
10
+ "shift": 1.0,
11
+ "use_dynamic_shifting": false
12
+ }
scripts/download_weights.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ import os
3
+
4
+ import click
5
+ from huggingface_hub import snapshot_download
6
+
7
+
8
+ # Based off of Kijai's script
9
+ @click.command()
10
+ @click.argument('output_dir', required=True)
11
+ def download_weights(output_dir):
12
+ repo_id = "genmo/mochi-1-preview"
13
+ model = "dit.safetensors"
14
+ decoder = "decoder.safetensors"
15
+ encoder = "encoder.safetensors"
16
+
17
+ if not os.path.exists(output_dir):
18
+ print(f"Creating output directory: {output_dir}")
19
+ os.makedirs(output_dir, exist_ok=True)
20
+
21
+ def download_file(repo_id, output_dir, filename, description):
22
+ file_path = os.path.join(output_dir, filename)
23
+ if not os.path.exists(file_path):
24
+ print(f"Downloading mochi {description} to: {file_path}")
25
+ snapshot_download(
26
+ repo_id=repo_id,
27
+ allow_patterns=[f"*{filename}*"],
28
+ local_dir=output_dir,
29
+ local_dir_use_symlinks=False,
30
+ )
31
+ else:
32
+ print(f"{description} already exists in: {file_path}")
33
+ assert os.path.exists(file_path)
34
+
35
+ download_file(repo_id, output_dir, decoder, "decoder")
36
+ download_file(repo_id, output_dir, encoder, "encoder")
37
+ download_file(repo_id, output_dir, model, "model")
38
+
39
+
40
+ if __name__ == "__main__":
41
+ download_weights()
scripts/format.bash ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #! /bin/bash
2
+ set -euxo pipefail
3
+ ruff format src demos
4
+ ruff check --fix --select I src
5
+ ruff check --fix --select I demos
scripts/pytorch_to_safe_tensors.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ from pathlib import Path
3
+
4
+ import click
5
+ import torch
6
+ from safetensors.torch import save_file
7
+
8
+
9
+ @click.command()
10
+ @click.argument("input_path", type=click.Path(exists=True))
11
+ def convert_to_safetensors(input_path):
12
+ model = torch.load(input_path)
13
+ model = {
14
+ k: v.contiguous() for k, v in model.items()
15
+ }
16
+ assert 'vae_ema' not in model
17
+ input_path = Path(input_path)
18
+ output_path = input_path.with_suffix(".safetensors")
19
+ save_file(model, str(output_path))
20
+ click.echo(f"Converted {input_path} to {output_path}")
21
+
22
+
23
+ if __name__ == "__main__":
24
+ convert_to_safetensors()
scripts/typecheck.bash ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ #! /bin/bash
2
+ npx pyright
scripts/weights_to_fp8.py ADDED
File without changes
src/genmo/lib/attn_imports.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import contextmanager
2
+
3
+ import torch
4
+
5
+
6
+ try:
7
+ from flash_attn import flash_attn_varlen_func as flash_varlen_attn
8
+ except ImportError:
9
+ flash_varlen_attn = None
10
+
11
+ try:
12
+ from sageattention import sageattn as sage_attn
13
+ except ImportError:
14
+ sage_attn = None
15
+
16
+ from torch.nn.attention import SDPBackend, sdpa_kernel
17
+
18
+ training_backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
19
+ eval_backends = list(training_backends)
20
+ if torch.cuda.get_device_properties(0).major >= 9.0:
21
+ # Enable fast CuDNN attention on Hopper.
22
+ # This gives NaN on the backward pass for some reason,
23
+ # so only use it for evaluation.
24
+ eval_backends.append(SDPBackend.CUDNN_ATTENTION)
25
+
26
+ @contextmanager
27
+ def sdpa_attn_ctx(training: bool = False):
28
+ with sdpa_kernel(training_backends if training else eval_backends):
29
+ yield
src/genmo/lib/progress.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ from typing import Any, Iterable, Iterator, Optional
3
+
4
+ try:
5
+ from tqdm import tqdm
6
+ except ImportError:
7
+ tqdm = None
8
+
9
+ try:
10
+ from ray.experimental.tqdm_ray import tqdm as ray_tqdm
11
+ except:
12
+ ray_tqdm = None
13
+
14
+ # Global state
15
+ _current_progress_type = "none"
16
+ _is_progress_bar_active = False
17
+
18
+
19
+ class DummyProgressBar:
20
+ """A no-op progress bar that mimics tqdm interface"""
21
+
22
+ def __init__(self, iterable=None, **kwargs):
23
+ self.iterable = iterable
24
+
25
+ def __iter__(self):
26
+ return iter(self.iterable)
27
+
28
+ def update(self, n=1):
29
+ pass
30
+
31
+ def close(self):
32
+ pass
33
+
34
+ def set_description(self, desc):
35
+ pass
36
+
37
+
38
+ def get_new_progress_bar(iterable: Optional[Iterable] = None, **kwargs) -> Any:
39
+ if not _is_progress_bar_active:
40
+ return DummyProgressBar(iterable=iterable, **kwargs)
41
+
42
+ if _current_progress_type == "tqdm":
43
+ if tqdm is None:
44
+ raise ImportError("tqdm is required but not installed. Please install tqdm to use the tqdm progress bar.")
45
+ return tqdm(iterable=iterable, **kwargs)
46
+ elif _current_progress_type == "ray_tqdm":
47
+ if ray_tqdm is None:
48
+ raise ImportError("ray is required but not installed. Please install ray to use the ray_tqdm progress bar.")
49
+ return ray_tqdm(iterable=iterable, **kwargs)
50
+ return DummyProgressBar(iterable=iterable, **kwargs)
51
+
52
+
53
+ @contextlib.contextmanager
54
+ def progress_bar(type: str = "none", enabled=True):
55
+ """
56
+ Context manager for setting progress bar type and options.
57
+
58
+ Args:
59
+ type: Type of progress bar ("none" or "tqdm")
60
+ **options: Options to pass to the progress bar (e.g., total, desc)
61
+
62
+ Raises:
63
+ ValueError: If progress bar type is invalid
64
+ RuntimeError: If progress bars are nested
65
+
66
+ Example:
67
+ with progress_bar(type="tqdm", total=100):
68
+ for i in get_new_progress_bar(range(100)):
69
+ process(i)
70
+ """
71
+ if type not in ("none", "tqdm", "ray_tqdm"):
72
+ raise ValueError("Progress bar type must be 'none' or 'tqdm' or 'ray_tqdm'")
73
+ if not enabled:
74
+ type = "none"
75
+ global _current_progress_type, _is_progress_bar_active
76
+
77
+ if _is_progress_bar_active:
78
+ raise RuntimeError("Nested progress bars are not supported")
79
+
80
+ _is_progress_bar_active = True
81
+ _current_progress_type = type
82
+
83
+ try:
84
+ yield
85
+ finally:
86
+ _is_progress_bar_active = False
87
+ _current_progress_type = "none"
src/genmo/lib/utils.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import tempfile
4
+ import time
5
+
6
+ import numpy as np
7
+ from moviepy.editor import ImageSequenceClip
8
+ from PIL import Image
9
+
10
+ from genmo.lib.progress import get_new_progress_bar
11
+
12
+
13
+ class Timer:
14
+ def __init__(self):
15
+ self.times = {} # Dictionary to store times per stage
16
+
17
+ def __call__(self, name):
18
+ print(f"Timing {name}")
19
+ return self.TimerContextManager(self, name)
20
+
21
+ def print_stats(self):
22
+ total_time = sum(self.times.values())
23
+ # Print table header
24
+ print("{:<20} {:>10} {:>10}".format("Stage", "Time(s)", "Percent"))
25
+ for name, t in self.times.items():
26
+ percent = (t / total_time) * 100 if total_time > 0 else 0
27
+ print("{:<20} {:>10.2f} {:>9.2f}%".format(name, t, percent))
28
+
29
+ class TimerContextManager:
30
+ def __init__(self, outer, name):
31
+ self.outer = outer # Reference to the Timer instance
32
+ self.name = name
33
+ self.start_time = None
34
+
35
+ def __enter__(self):
36
+ self.start_time = time.perf_counter()
37
+ return self
38
+
39
+ def __exit__(self, exc_type, exc_value, traceback):
40
+ end_time = time.perf_counter()
41
+ elapsed = end_time - self.start_time
42
+ self.outer.times[self.name] = self.outer.times.get(self.name, 0) + elapsed
43
+
44
+
45
+ def save_video(final_frames, output_path, fps=30):
46
+ assert final_frames.ndim == 4 and final_frames.shape[3] == 3, f"invalid shape: {final_frames} (need t h w c)"
47
+ if final_frames.dtype != np.uint8:
48
+ final_frames = (final_frames * 255).astype(np.uint8)
49
+ ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path)
50
+
51
+
52
+ def create_memory_tracker():
53
+ import torch
54
+
55
+ previous = [None] # Use list for mutable closure state
56
+
57
+ def track(label="all2all"):
58
+ current = torch.cuda.memory_allocated() / 1e9
59
+ if previous[0] is not None:
60
+ diff = current - previous[0]
61
+ sign = "+" if diff >= 0 else ""
62
+ print(f"GPU memory ({label}): {current:.2f} GB ({sign}{diff:.2f} GB)")
63
+ else:
64
+ print(f"GPU memory ({label}): {current:.2f} GB")
65
+ previous[0] = current # type: ignore
66
+
67
+ return track
src/genmo/mochi_preview/__init__.py ADDED
File without changes
src/genmo/mochi_preview/dit/joint_model/__init__.py ADDED
File without changes
src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py ADDED
@@ -0,0 +1,737 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, List, Optional, Tuple
3
+ import warnings
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch.nn.attention import sdpa_kernel
10
+
11
+ import genmo.mochi_preview.dit.joint_model.context_parallel as cp
12
+ from genmo.lib.attn_imports import flash_varlen_attn, sage_attn, sdpa_attn_ctx
13
+ from genmo.mochi_preview.dit.joint_model.layers import (
14
+ FeedForward,
15
+ PatchEmbed,
16
+ RMSNorm,
17
+ TimestepEmbedder,
18
+ )
19
+ from genmo.mochi_preview.dit.joint_model.lora import LoraLinear
20
+ from genmo.mochi_preview.dit.joint_model.mod_rmsnorm import modulated_rmsnorm
21
+ from genmo.mochi_preview.dit.joint_model.residual_tanh_gated_rmsnorm import (
22
+ residual_tanh_gated_rmsnorm,
23
+ )
24
+ from genmo.mochi_preview.dit.joint_model.rope_mixed import (
25
+ compute_mixed_rotation,
26
+ create_position_matrix,
27
+ )
28
+ from genmo.mochi_preview.dit.joint_model.temporal_rope import apply_rotary_emb_qk_real
29
+ from genmo.mochi_preview.dit.joint_model.utils import (
30
+ AttentionPool,
31
+ modulate,
32
+ pad_and_split_xy,
33
+ )
34
+
35
+ COMPILE_FINAL_LAYER = os.environ.get("COMPILE_DIT") == "1"
36
+ COMPILE_MMDIT_BLOCK = os.environ.get("COMPILE_DIT") == "1"
37
+
38
+
39
+ def ck(fn, *args, enabled=True, **kwargs) -> torch.Tensor:
40
+ if enabled:
41
+ return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs, use_reentrant=False)
42
+
43
+ return fn(*args, **kwargs)
44
+
45
+
46
+ class AsymmetricAttention(nn.Module):
47
+ def __init__(
48
+ self,
49
+ dim_x: int,
50
+ dim_y: int,
51
+ num_heads: int = 8,
52
+ qkv_bias: bool = True,
53
+ qk_norm: bool = False,
54
+ update_y: bool = True,
55
+ out_bias: bool = True,
56
+ attention_mode: str = "flash",
57
+ softmax_scale: Optional[float] = None,
58
+ device: Optional[torch.device] = None,
59
+ # Disable LoRA by default ...
60
+ qkv_proj_lora_rank: int = 0,
61
+ qkv_proj_lora_alpha: int = 0,
62
+ qkv_proj_lora_dropout: float = 0.0,
63
+ out_proj_lora_rank: int = 0,
64
+ out_proj_lora_alpha: int = 0,
65
+ out_proj_lora_dropout: float = 0.0,
66
+ ):
67
+ super().__init__()
68
+ self.attention_mode = attention_mode
69
+ self.dim_x = dim_x
70
+ self.dim_y = dim_y
71
+ self.num_heads = num_heads
72
+ self.head_dim = dim_x // num_heads
73
+ self.update_y = update_y
74
+ self.softmax_scale = softmax_scale
75
+ if dim_x % num_heads != 0:
76
+ raise ValueError(f"dim_x={dim_x} should be divisible by num_heads={num_heads}")
77
+
78
+ # Input layers.
79
+ self.qkv_bias = qkv_bias
80
+ qkv_lora_kwargs = dict(
81
+ bias=qkv_bias,
82
+ device=device,
83
+ r=qkv_proj_lora_rank,
84
+ lora_alpha=qkv_proj_lora_alpha,
85
+ lora_dropout=qkv_proj_lora_dropout,
86
+ )
87
+ self.qkv_x = LoraLinear(dim_x, 3 * dim_x, **qkv_lora_kwargs)
88
+ # Project text features to match visual features (dim_y -> dim_x)
89
+ self.qkv_y = LoraLinear(dim_y, 3 * dim_x, **qkv_lora_kwargs)
90
+
91
+ # Query and key normalization for stability.
92
+ assert qk_norm
93
+ self.q_norm_x = RMSNorm(self.head_dim, device=device)
94
+ self.k_norm_x = RMSNorm(self.head_dim, device=device)
95
+ self.q_norm_y = RMSNorm(self.head_dim, device=device)
96
+ self.k_norm_y = RMSNorm(self.head_dim, device=device)
97
+
98
+ # Output layers. y features go back down from dim_x -> dim_y.
99
+ proj_lora_kwargs = dict(
100
+ bias=out_bias,
101
+ device=device,
102
+ r=out_proj_lora_rank,
103
+ lora_alpha=out_proj_lora_alpha,
104
+ lora_dropout=out_proj_lora_dropout,
105
+ )
106
+ self.proj_x = LoraLinear(dim_x, dim_x, **proj_lora_kwargs)
107
+ self.proj_y = LoraLinear(dim_x, dim_y, **proj_lora_kwargs) if update_y else nn.Identity()
108
+
109
+ def run_qkv_y(self, y):
110
+ cp_rank, cp_size = cp.get_cp_rank_size()
111
+ local_heads = self.num_heads // cp_size
112
+
113
+ if cp.is_cp_active():
114
+ # Only predict local heads.
115
+ assert not self.qkv_bias
116
+ W_qkv_y = self.qkv_y.weight.view(3, self.num_heads, self.head_dim, self.dim_y)
117
+ W_qkv_y = W_qkv_y.narrow(1, cp_rank * local_heads, local_heads)
118
+ W_qkv_y = W_qkv_y.reshape(3 * local_heads * self.head_dim, self.dim_y)
119
+ qkv_y = F.linear(y, W_qkv_y, None) # (B, L, 3 * local_h * head_dim)
120
+ else:
121
+ qkv_y = self.qkv_y(y) # (B, L, 3 * dim)
122
+
123
+ qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim)
124
+ q_y, k_y, v_y = qkv_y.unbind(2)
125
+
126
+ q_y = self.q_norm_y(q_y)
127
+ k_y = self.k_norm_y(k_y)
128
+ return q_y, k_y, v_y
129
+
130
+ def prepare_qkv(
131
+ self,
132
+ x: torch.Tensor, # (B, M, dim_x)
133
+ y: torch.Tensor, # (B, L, dim_y)
134
+ *,
135
+ scale_x: torch.Tensor,
136
+ scale_y: torch.Tensor,
137
+ rope_cos: torch.Tensor,
138
+ rope_sin: torch.Tensor,
139
+ valid_token_indices: torch.Tensor,
140
+ max_seqlen_in_batch: int,
141
+ ):
142
+ # Process visual features
143
+ x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size
144
+ qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x)
145
+ assert qkv_x.dtype == torch.bfloat16
146
+
147
+ qkv_x = cp.all_to_all_collect_tokens(qkv_x, self.num_heads) # (3, B, N, local_h, head_dim)
148
+
149
+ # Split qkv_x into q, k, v
150
+ q_x, k_x, v_x = qkv_x.unbind(0) # (B, N, local_h, head_dim)
151
+ q_x = self.q_norm_x(q_x)
152
+ q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
153
+ k_x = self.k_norm_x(k_x)
154
+ k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
155
+
156
+ # Concatenate streams
157
+ B, N, num_heads, head_dim = q_x.size()
158
+ D = num_heads * head_dim
159
+
160
+ # Process text features
161
+ if B == 1:
162
+ text_seqlen = max_seqlen_in_batch - N
163
+ if text_seqlen > 0:
164
+ y = y[:, :text_seqlen] # Remove padding tokens.
165
+ y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
166
+ q_y, k_y, v_y = self.run_qkv_y(y) # (B, L, local_heads, head_dim)
167
+
168
+ q = torch.cat([q_x, q_y], dim=1)
169
+ k = torch.cat([k_x, k_y], dim=1)
170
+ v = torch.cat([v_x, v_y], dim=1)
171
+ else:
172
+ q, k, v = q_x, k_x, v_x
173
+ else:
174
+ y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
175
+ q_y, k_y, v_y = self.run_qkv_y(y) # (B, L, local_heads, head_dim)
176
+
177
+ indices = valid_token_indices[:, None].expand(-1, D)
178
+ q = torch.cat([q_x, q_y], dim=1).view(-1, D).gather(0, indices) # (total, D)
179
+ k = torch.cat([k_x, k_y], dim=1).view(-1, D).gather(0, indices) # (total, D)
180
+ v = torch.cat([v_x, v_y], dim=1).view(-1, D).gather(0, indices) # (total, D)
181
+
182
+ q = q.view(-1, num_heads, head_dim)
183
+ k = k.view(-1, num_heads, head_dim)
184
+ v = v.view(-1, num_heads, head_dim)
185
+ return q, k, v
186
+
187
+ @torch.autocast("cuda", enabled=False)
188
+ def flash_attention(self, q, k, v, cu_seqlens, max_seqlen_in_batch, total, local_dim):
189
+ out: torch.Tensor = flash_varlen_attn(
190
+ q, k, v,
191
+ cu_seqlens_q=cu_seqlens,
192
+ cu_seqlens_k=cu_seqlens,
193
+ max_seqlen_q=max_seqlen_in_batch,
194
+ max_seqlen_k=max_seqlen_in_batch,
195
+ dropout_p=0.0,
196
+ softmax_scale=self.softmax_scale,
197
+ ) # (total, local_heads, head_dim)
198
+ return out.view(total, local_dim)
199
+
200
+ def sdpa_attention(self, q, k, v):
201
+ with sdpa_attn_ctx(training=self.training):
202
+ out = F.scaled_dot_product_attention(
203
+ q, k, v,
204
+ attn_mask=None,
205
+ dropout_p=0.0,
206
+ is_causal=False,
207
+ )
208
+ return out
209
+
210
+ @torch.autocast("cuda", enabled=False)
211
+ def sage_attention(self, q, k, v):
212
+ return sage_attn(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
213
+
214
+ def run_attention(
215
+ self,
216
+ q: torch.Tensor, # (total <= B * (N + L), num_heads, head_dim)
217
+ k: torch.Tensor, # (total <= B * (N + L), num_heads, head_dim)
218
+ v: torch.Tensor, # (total <= B * (N + L), num_heads, head_dim)
219
+ *,
220
+ B: int,
221
+ cu_seqlens: Optional[torch.Tensor] = None,
222
+ max_seqlen_in_batch: Optional[int] = None,
223
+ ):
224
+ _, cp_size = cp.get_cp_rank_size()
225
+ assert self.num_heads % cp_size == 0
226
+ local_heads = self.num_heads // cp_size
227
+ local_dim = local_heads * self.head_dim
228
+
229
+ # Check shapes
230
+ assert q.ndim == 3 and k.ndim == 3 and v.ndim == 3
231
+ total = q.size(0)
232
+ assert k.size(0) == total and v.size(0) == total
233
+
234
+ if self.attention_mode == "flash":
235
+ out = self.flash_attention(
236
+ q, k, v, cu_seqlens, max_seqlen_in_batch, total, local_dim) # (total, local_dim)
237
+ else:
238
+ assert B == 1, \
239
+ f"Non-flash attention mode {self.attention_mode} only supports batch size 1, got {B}"
240
+
241
+ q = rearrange(q, "(b s) h d -> b h s d", b=B)
242
+ k = rearrange(k, "(b s) h d -> b h s d", b=B)
243
+ v = rearrange(v, "(b s) h d -> b h s d", b=B)
244
+
245
+ if self.attention_mode == "sdpa":
246
+ out = self.sdpa_attention(q, k, v) # (B, local_heads, seq_len, head_dim)
247
+ elif self.attention_mode == "sage":
248
+ out = self.sage_attention(q, k, v) # (B, local_heads, seq_len, head_dim)
249
+ else:
250
+ raise ValueError(f"Unknown attention mode: {self.attention_mode}")
251
+
252
+ out = rearrange(out, "b h s d -> (b s) (h d)")
253
+
254
+ return out
255
+
256
+ def post_attention(
257
+ self,
258
+ out: torch.Tensor,
259
+ B: int,
260
+ M: int,
261
+ L: int,
262
+ dtype: torch.dtype,
263
+ valid_token_indices: torch.Tensor,
264
+ ):
265
+ """
266
+ Args:
267
+ out: (total <= B * (N + L), local_dim)
268
+ valid_token_indices: (total <= B * (N + L),)
269
+ B: Batch size
270
+ M: Number of visual tokens per context parallel rank
271
+ L: Number of text tokens
272
+ dtype: Data type of the input and output tensors
273
+
274
+ Returns:
275
+ x: (B, N, dim_x) tensor of visual tokens where N = M * cp_size
276
+ y: (B, L, dim_y) tensor of text token features
277
+ """
278
+ _, cp_size = cp.get_cp_rank_size()
279
+ local_heads = self.num_heads // cp_size
280
+ local_dim = local_heads * self.head_dim
281
+ N = M * cp_size
282
+
283
+ # Split sequence into visual and text tokens, adding back padding.
284
+ if B == 1:
285
+ out = out.view(B, -1, local_dim)
286
+ if out.size(1) > N:
287
+ x, y = torch.tensor_split(out, (N,), dim=1) # (B, N, local_dim), (B, <= L, local_dim)
288
+ y = F.pad(y, (0, 0, 0, L - y.size(1))) # (B, L, local_dim)
289
+ else:
290
+ # Empty prompt.
291
+ x, y = out, out.new_zeros(B, L, local_dim)
292
+ else:
293
+ x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, dtype)
294
+ assert x.size() == (B, N, local_dim)
295
+ assert y.size() == (B, L, local_dim)
296
+
297
+ # Communicate across context parallel ranks.
298
+ x = x.view(B, N, local_heads, self.head_dim)
299
+ x = cp.all_to_all_collect_heads(x) # (B, M, dim_x = num_heads * head_dim)
300
+ if cp.is_cp_active():
301
+ y = cp.all_gather(y) # (cp_size * B, L, local_heads * head_dim)
302
+ y = rearrange(y, "(G B) L D -> B L (G D)", G=cp_size, D=local_dim) # (B, L, dim_x)
303
+
304
+ x = self.proj_x(x)
305
+ y = self.proj_y(y)
306
+ return x, y
307
+
308
+ def forward(
309
+ self,
310
+ x: torch.Tensor, # (B, M, dim_x)
311
+ y: torch.Tensor, # (B, L, dim_y)
312
+ *,
313
+ scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
314
+ scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
315
+ packed_indices: Dict[str, torch.Tensor] = None,
316
+ checkpoint_qkv: bool = False,
317
+ checkpoint_post_attn: bool = False,
318
+ **rope_rotation,
319
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
320
+ """Forward pass of asymmetric multi-modal attention.
321
+
322
+ Args:
323
+ x: (B, M, dim_x) tensor of visual tokens
324
+ y: (B, L, dim_y) tensor of text token features
325
+ packed_indices: Dict with keys for Flash Attention
326
+ num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
327
+
328
+ Returns:
329
+ x: (B, M, dim_x) tensor of visual tokens after multi-modal attention
330
+ y: (B, L, dim_y) tensor of text token features after multi-modal attention
331
+ """
332
+ B, L, _ = y.shape
333
+ _, M, _ = x.shape
334
+
335
+ # Predict a packed QKV tensor from visual and text features.
336
+ q, k, v = ck(self.prepare_qkv,
337
+ x=x,
338
+ y=y,
339
+ scale_x=scale_x,
340
+ scale_y=scale_y,
341
+ rope_cos=rope_rotation.get("rope_cos"),
342
+ rope_sin=rope_rotation.get("rope_sin"),
343
+ valid_token_indices=packed_indices["valid_token_indices_kv"],
344
+ max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"],
345
+ enabled=checkpoint_qkv,
346
+ ) # (total <= B * (N + L), 3, local_heads, head_dim)
347
+
348
+ # Self-attention is expensive, so don't checkpoint it.
349
+ out = self.run_attention(
350
+ q, k, v, B=B,
351
+ cu_seqlens=packed_indices["cu_seqlens_kv"],
352
+ max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"],
353
+ )
354
+
355
+ x, y = ck(self.post_attention,
356
+ out,
357
+ B=B, M=M, L=L,
358
+ dtype=v.dtype,
359
+ valid_token_indices=packed_indices["valid_token_indices_kv"],
360
+ enabled=checkpoint_post_attn,
361
+ )
362
+
363
+ return x, y
364
+
365
+
366
+ @torch.compile(disable=not COMPILE_MMDIT_BLOCK)
367
+ class AsymmetricJointBlock(nn.Module):
368
+ def __init__(
369
+ self,
370
+ hidden_size_x: int,
371
+ hidden_size_y: int,
372
+ num_heads: int,
373
+ *,
374
+ mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens.
375
+ mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
376
+ update_y: bool = True, # Whether to update text tokens in this block.
377
+ device: Optional[torch.device] = None,
378
+ **block_kwargs,
379
+ ):
380
+ super().__init__()
381
+ self.update_y = update_y
382
+ self.hidden_size_x = hidden_size_x
383
+ self.hidden_size_y = hidden_size_y
384
+ self.mod_x = nn.Linear(hidden_size_x, 4 * hidden_size_x, device=device)
385
+ if self.update_y:
386
+ self.mod_y = nn.Linear(hidden_size_x, 4 * hidden_size_y, device=device)
387
+ else:
388
+ self.mod_y = nn.Linear(hidden_size_x, hidden_size_y, device=device)
389
+
390
+ # Self-attention:
391
+ self.attn = AsymmetricAttention(
392
+ hidden_size_x,
393
+ hidden_size_y,
394
+ num_heads=num_heads,
395
+ update_y=update_y,
396
+ device=device,
397
+ **block_kwargs,
398
+ )
399
+
400
+ # MLP.
401
+ mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
402
+ assert mlp_hidden_dim_x == int(1536 * 8)
403
+ self.mlp_x = FeedForward(
404
+ in_features=hidden_size_x,
405
+ hidden_size=mlp_hidden_dim_x,
406
+ multiple_of=256,
407
+ ffn_dim_multiplier=None,
408
+ device=device,
409
+ )
410
+
411
+ # MLP for text not needed in last block.
412
+ if self.update_y:
413
+ mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y)
414
+ self.mlp_y = FeedForward(
415
+ in_features=hidden_size_y,
416
+ hidden_size=mlp_hidden_dim_y,
417
+ multiple_of=256,
418
+ ffn_dim_multiplier=None,
419
+ device=device,
420
+ )
421
+
422
+ def forward(
423
+ self,
424
+ x: torch.Tensor,
425
+ c: torch.Tensor,
426
+ y: torch.Tensor,
427
+ # TODO: These could probably just go into attn_kwargs
428
+ checkpoint_ff: bool = False,
429
+ checkpoint_qkv: bool = False,
430
+ checkpoint_post_attn: bool = False,
431
+ **attn_kwargs,
432
+ ):
433
+ """Forward pass of a block.
434
+
435
+ Args:
436
+ x: (B, N, dim) tensor of visual tokens
437
+ c: (B, dim) tensor of conditioned features
438
+ y: (B, L, dim) tensor of text tokens
439
+ num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
440
+
441
+ Returns:
442
+ x: (B, N, dim) tensor of visual tokens after block
443
+ y: (B, L, dim) tensor of text tokens after block
444
+ """
445
+ N = x.size(1)
446
+
447
+ c = F.silu(c)
448
+ mod_x = self.mod_x(c)
449
+ scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1)
450
+ mod_y = self.mod_y(c)
451
+
452
+ if self.update_y:
453
+ scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
454
+ else:
455
+ scale_msa_y = mod_y
456
+
457
+ # Self-attention block.
458
+ x_attn, y_attn = self.attn(
459
+ x,
460
+ y,
461
+ scale_x=scale_msa_x,
462
+ scale_y=scale_msa_y,
463
+ checkpoint_qkv=checkpoint_qkv,
464
+ checkpoint_post_attn=checkpoint_post_attn,
465
+ **attn_kwargs,
466
+ )
467
+
468
+ assert x_attn.size(1) == N
469
+ x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
470
+
471
+ if self.update_y:
472
+ y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
473
+
474
+ # MLP block.
475
+ x = ck(self.ff_block_x, x, scale_mlp_x, gate_mlp_x, enabled=checkpoint_ff)
476
+ if self.update_y:
477
+ y = ck(self.ff_block_y, y, scale_mlp_y, gate_mlp_y, enabled=checkpoint_ff) # type: ignore
478
+ return x, y
479
+
480
+ def ff_block_x(self, x, scale_x, gate_x):
481
+ x_mod = modulated_rmsnorm(x, scale_x)
482
+ x_res = self.mlp_x(x_mod)
483
+ x = residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm
484
+ return x
485
+
486
+ def ff_block_y(self, y, scale_y, gate_y):
487
+ y_mod = modulated_rmsnorm(y, scale_y)
488
+ y_res = self.mlp_y(y_mod)
489
+ y = residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm
490
+ return y
491
+
492
+
493
+ @torch.compile(disable=not COMPILE_FINAL_LAYER)
494
+ class FinalLayer(nn.Module):
495
+ """
496
+ The final layer of DiT.
497
+ """
498
+
499
+ def __init__(
500
+ self,
501
+ hidden_size,
502
+ patch_size,
503
+ out_channels,
504
+ device: Optional[torch.device] = None,
505
+ ):
506
+ super().__init__()
507
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, device=device)
508
+ self.mod = nn.Linear(hidden_size, 2 * hidden_size, device=device)
509
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, device=device)
510
+
511
+ def forward(self, x, c):
512
+ c = F.silu(c)
513
+ shift, scale = self.mod(c).chunk(2, dim=1)
514
+ x = modulate(self.norm_final(x), shift, scale)
515
+ x = self.linear(x)
516
+ return x
517
+
518
+
519
+ class AsymmDiTJoint(nn.Module):
520
+ """
521
+ Diffusion model with a Transformer backbone.
522
+
523
+ Ingests text embeddings instead of a label.
524
+ """
525
+
526
+ def __init__(
527
+ self,
528
+ *,
529
+ patch_size=2,
530
+ in_channels=4,
531
+ hidden_size_x=1152,
532
+ hidden_size_y=1152,
533
+ depth=48,
534
+ num_heads=16,
535
+ mlp_ratio_x=8.0,
536
+ mlp_ratio_y=4.0,
537
+ t5_feat_dim: int = 4096,
538
+ t5_token_length: int = 256,
539
+ patch_embed_bias: bool = True,
540
+ timestep_mlp_bias: bool = True,
541
+ timestep_scale: Optional[float] = None,
542
+ use_extended_posenc: bool = False,
543
+ rope_theta: float = 10000.0,
544
+ device: Optional[torch.device] = None,
545
+ **block_kwargs,
546
+ ):
547
+ super().__init__()
548
+ self.in_channels = in_channels
549
+ self.out_channels = in_channels
550
+ self.patch_size = patch_size
551
+ self.num_heads = num_heads
552
+ self.hidden_size_x = hidden_size_x
553
+ self.hidden_size_y = hidden_size_y
554
+ self.head_dim = hidden_size_x // num_heads # Head dimension and count is determined by visual.
555
+ self.use_extended_posenc = use_extended_posenc
556
+ self.t5_token_length = t5_token_length
557
+ self.t5_feat_dim = t5_feat_dim
558
+ self.rope_theta = rope_theta # Scaling factor for frequency computation for temporal RoPE.
559
+
560
+ self.x_embedder = PatchEmbed(
561
+ patch_size=patch_size,
562
+ in_chans=in_channels,
563
+ embed_dim=hidden_size_x,
564
+ bias=patch_embed_bias,
565
+ device=device,
566
+ )
567
+ # Conditionings
568
+ # Timestep
569
+ self.t_embedder = TimestepEmbedder(hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale)
570
+
571
+ # Caption Pooling (T5)
572
+ self.t5_y_embedder = AttentionPool(t5_feat_dim, num_heads=8, output_dim=hidden_size_x, device=device)
573
+
574
+ # Dense Embedding Projection (T5)
575
+ self.t5_yproj = nn.Linear(t5_feat_dim, hidden_size_y, bias=True, device=device)
576
+
577
+ # Initialize pos_frequencies as an empty parameter.
578
+ self.pos_frequencies = nn.Parameter(torch.empty(3, self.num_heads, self.head_dim // 2, device=device))
579
+
580
+ # for depth 48:
581
+ # b = 0: AsymmetricJointBlock, update_y=True
582
+ # b = 1: AsymmetricJointBlock, update_y=True
583
+ # ...
584
+ # b = 46: AsymmetricJointBlock, update_y=True
585
+ # b = 47: AsymmetricJointBlock, update_y=False. No need to update text features.
586
+ blocks = []
587
+ for b in range(depth):
588
+ # Joint multi-modal block
589
+ update_y = b < depth - 1
590
+ block = AsymmetricJointBlock(
591
+ hidden_size_x,
592
+ hidden_size_y,
593
+ num_heads,
594
+ mlp_ratio_x=mlp_ratio_x,
595
+ mlp_ratio_y=mlp_ratio_y,
596
+ update_y=update_y,
597
+ device=device,
598
+ **block_kwargs,
599
+ )
600
+
601
+ blocks.append(block)
602
+ self.blocks = nn.ModuleList(blocks)
603
+
604
+ self.final_layer = FinalLayer(hidden_size_x, patch_size, self.out_channels, device=device)
605
+
606
+ def embed_x(self, x: torch.Tensor) -> torch.Tensor:
607
+ """
608
+ Args:
609
+ x: (B, C=12, T, H, W) tensor of visual tokens
610
+
611
+ Returns:
612
+ x: (B, C=3072, N) tensor of visual tokens with positional embedding.
613
+ """
614
+ return self.x_embedder(x) # Convert BcTHW to BCN
615
+
616
+ @torch.compile(disable=not COMPILE_MMDIT_BLOCK)
617
+ def prepare(
618
+ self,
619
+ x: torch.Tensor,
620
+ sigma: torch.Tensor,
621
+ t5_feat: torch.Tensor,
622
+ t5_mask: torch.Tensor,
623
+ ):
624
+ """Prepare input and conditioning embeddings."""
625
+
626
+ # Visual patch embeddings with positional encoding.
627
+ T, H, W = x.shape[-3:]
628
+ pH, pW = H // self.patch_size, W // self.patch_size
629
+ x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
630
+ assert x.ndim == 3
631
+ B = x.size(0)
632
+
633
+ # Construct position array of size [N, 3].
634
+ # pos[:, 0] is the frame index for each location,
635
+ # pos[:, 1] is the row index for each location, and
636
+ # pos[:, 2] is the column index for each location.
637
+ N = T * pH * pW
638
+ assert x.size(1) == N
639
+ pos = create_position_matrix(T, pH=pH, pW=pW, device=x.device, dtype=torch.float32) # (N, 3)
640
+ rope_cos, rope_sin = compute_mixed_rotation(
641
+ freqs=self.pos_frequencies, pos=pos
642
+ ) # Each are (N, num_heads, dim // 2)
643
+
644
+ # Global vector embedding for conditionings.
645
+ c_t = self.t_embedder(1 - sigma) # (B, D)
646
+
647
+ # Pool T5 tokens using attention pooler
648
+ # Note y_feat[1] contains T5 token features.
649
+ assert (
650
+ t5_feat.size(1) == self.t5_token_length
651
+ ), f"Expected L={self.t5_token_length}, got {t5_feat.shape} for y_feat."
652
+ t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask) # (B, D)
653
+ assert t5_y_pool.size(0) == B, f"Expected B={B}, got {t5_y_pool.shape} for t5_y_pool."
654
+
655
+ c = c_t + t5_y_pool
656
+
657
+ y_feat = self.t5_yproj(t5_feat) # (B, L, t5_feat_dim) --> (B, L, D)
658
+
659
+ return x, c, y_feat, rope_cos, rope_sin
660
+
661
+ def forward(
662
+ self,
663
+ x: torch.Tensor,
664
+ sigma: torch.Tensor,
665
+ y_feat: List[torch.Tensor],
666
+ y_mask: List[torch.Tensor],
667
+ packed_indices: Dict[str, torch.Tensor] = None,
668
+ rope_cos: torch.Tensor = None,
669
+ rope_sin: torch.Tensor = None,
670
+ num_ff_checkpoint: int = 0,
671
+ num_qkv_checkpoint: int = 0,
672
+ num_post_attn_checkpoint: int = 0,
673
+ ):
674
+ """Forward pass of DiT.
675
+
676
+ Args:
677
+ x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
678
+ sigma: (B,) tensor of noise standard deviations
679
+ y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
680
+ y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
681
+ packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
682
+ """
683
+ _, _, T, H, W = x.shape
684
+
685
+ if self.pos_frequencies.dtype != torch.float32:
686
+ warnings.warn(f"pos_frequencies dtype {self.pos_frequencies.dtype} != torch.float32")
687
+
688
+ # Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
689
+ # Have to call sdpa_kernel outside of a torch.compile region.
690
+ with sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
691
+ x, c, y_feat, rope_cos, rope_sin = self.prepare(x, sigma, y_feat[0], y_mask[0])
692
+ del y_mask
693
+
694
+ cp_rank, cp_size = cp.get_cp_rank_size()
695
+ N = x.size(1)
696
+ M = N // cp_size
697
+ assert N % cp_size == 0, f"Visual sequence length ({x.shape[1]}) must be divisible by cp_size ({cp_size})."
698
+
699
+ if cp_size > 1:
700
+ x = x.narrow(1, cp_rank * M, M)
701
+
702
+ assert self.num_heads % cp_size == 0
703
+ local_heads = self.num_heads // cp_size
704
+ rope_cos = rope_cos.narrow(1, cp_rank * local_heads, local_heads)
705
+ rope_sin = rope_sin.narrow(1, cp_rank * local_heads, local_heads)
706
+
707
+ for i, block in enumerate(self.blocks):
708
+ x, y_feat = block(
709
+ x,
710
+ c,
711
+ y_feat,
712
+ rope_cos=rope_cos,
713
+ rope_sin=rope_sin,
714
+ packed_indices=packed_indices,
715
+ checkpoint_ff=i < num_ff_checkpoint,
716
+ checkpoint_qkv=i < num_qkv_checkpoint,
717
+ checkpoint_post_attn=i < num_post_attn_checkpoint,
718
+ ) # (B, M, D), (B, L, D)
719
+ del y_feat # Final layers don't use dense text features.
720
+
721
+ x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
722
+
723
+ patch = x.size(2)
724
+ x = cp.all_gather(x)
725
+ x = rearrange(x, "(G B) M P -> B (G M) P", G=cp_size, P=patch)
726
+ x = rearrange(
727
+ x,
728
+ "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
729
+ T=T,
730
+ hp=H // self.patch_size,
731
+ wp=W // self.patch_size,
732
+ p1=self.patch_size,
733
+ p2=self.patch_size,
734
+ c=self.out_channels,
735
+ )
736
+
737
+ return x
src/genmo/mochi_preview/dit/joint_model/context_parallel.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+ from einops import rearrange
6
+
7
+ _CONTEXT_PARALLEL_GROUP = None
8
+ _CONTEXT_PARALLEL_RANK = None
9
+ _CONTEXT_PARALLEL_GROUP_SIZE = None
10
+ _CONTEXT_PARALLEL_GROUP_RANKS = None
11
+
12
+
13
+ def get_cp_rank_size() -> Tuple[int, int]:
14
+ if _CONTEXT_PARALLEL_GROUP:
15
+ assert isinstance(_CONTEXT_PARALLEL_RANK, int) and isinstance(_CONTEXT_PARALLEL_GROUP_SIZE, int)
16
+ return _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE
17
+ else:
18
+ return 0, 1
19
+
20
+
21
+ def local_shard(x: torch.Tensor, dim: int = 2) -> torch.Tensor:
22
+ if not _CONTEXT_PARALLEL_GROUP:
23
+ return x
24
+
25
+ cp_rank, cp_size = get_cp_rank_size()
26
+ return x.tensor_split(cp_size, dim=dim)[cp_rank]
27
+
28
+
29
+ def set_cp_group(cp_group, ranks, global_rank):
30
+ global _CONTEXT_PARALLEL_GROUP, _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE, _CONTEXT_PARALLEL_GROUP_RANKS
31
+ if _CONTEXT_PARALLEL_GROUP is not None:
32
+ raise RuntimeError("CP group already initialized.")
33
+ _CONTEXT_PARALLEL_GROUP = cp_group
34
+ _CONTEXT_PARALLEL_RANK = dist.get_rank(cp_group)
35
+ _CONTEXT_PARALLEL_GROUP_SIZE = dist.get_world_size(cp_group)
36
+ _CONTEXT_PARALLEL_GROUP_RANKS = ranks
37
+
38
+ assert _CONTEXT_PARALLEL_RANK == ranks.index(
39
+ global_rank
40
+ ), f"Rank mismatch: {global_rank} in {ranks} does not have position {_CONTEXT_PARALLEL_RANK} "
41
+ assert _CONTEXT_PARALLEL_GROUP_SIZE == len(
42
+ ranks
43
+ ), f"Group size mismatch: {_CONTEXT_PARALLEL_GROUP_SIZE} != len({ranks})"
44
+
45
+
46
+ def get_cp_group():
47
+ if _CONTEXT_PARALLEL_GROUP is None:
48
+ raise RuntimeError("CP group not initialized")
49
+ return _CONTEXT_PARALLEL_GROUP
50
+
51
+
52
+ def is_cp_active():
53
+ return _CONTEXT_PARALLEL_GROUP is not None
54
+
55
+
56
+ class AllGatherIntoTensorFunction(torch.autograd.Function):
57
+ @staticmethod
58
+ def forward(ctx, x: torch.Tensor, reduce_dtype, group: dist.ProcessGroup):
59
+ ctx.reduce_dtype = reduce_dtype
60
+ ctx.group = group
61
+ ctx.batch_size = x.size(0)
62
+ group_size = dist.get_world_size(group)
63
+
64
+ x = x.contiguous()
65
+ output = torch.empty(group_size * x.size(0), *x.shape[1:], dtype=x.dtype, device=x.device)
66
+ dist.all_gather_into_tensor(output, x, group=group)
67
+ return output
68
+
69
+
70
+ def all_gather(tensor: torch.Tensor) -> torch.Tensor:
71
+ if not _CONTEXT_PARALLEL_GROUP:
72
+ return tensor
73
+
74
+ return AllGatherIntoTensorFunction.apply(tensor, torch.float32, _CONTEXT_PARALLEL_GROUP)
75
+
76
+
77
+ @torch.compiler.disable()
78
+ def _all_to_all_single(output, input, group):
79
+ # Disable compilation since torch compile changes contiguity.
80
+ assert input.is_contiguous(), "Input tensor must be contiguous."
81
+ assert output.is_contiguous(), "Output tensor must be contiguous."
82
+ return dist.all_to_all_single(output, input, group=group)
83
+
84
+
85
+ class CollectTokens(torch.autograd.Function):
86
+ @staticmethod
87
+ def forward(ctx, qkv: torch.Tensor, group: dist.ProcessGroup, num_heads: int):
88
+ """Redistribute heads and receive tokens.
89
+
90
+ Args:
91
+ qkv: query, key or value. Shape: [B, M, 3 * num_heads * head_dim]
92
+
93
+ Returns:
94
+ qkv: shape: [3, B, N, local_heads, head_dim]
95
+
96
+ where M is the number of local tokens,
97
+ N = cp_size * M is the number of global tokens,
98
+ local_heads = num_heads // cp_size is the number of local heads.
99
+ """
100
+ ctx.group = group
101
+ ctx.num_heads = num_heads
102
+ cp_size = dist.get_world_size(group)
103
+ assert num_heads % cp_size == 0
104
+ ctx.local_heads = num_heads // cp_size
105
+
106
+ qkv = rearrange(
107
+ qkv,
108
+ "B M (qkv G h d) -> G M h B (qkv d)",
109
+ qkv=3,
110
+ G=cp_size,
111
+ h=ctx.local_heads,
112
+ ).contiguous()
113
+
114
+ output_chunks = torch.empty_like(qkv)
115
+ _all_to_all_single(output_chunks, qkv, group=group)
116
+
117
+ return rearrange(output_chunks, "G M h B (qkv d) -> qkv B (G M) h d", qkv=3)
118
+
119
+
120
+ def all_to_all_collect_tokens(x: torch.Tensor, num_heads: int) -> torch.Tensor:
121
+ if not _CONTEXT_PARALLEL_GROUP:
122
+ # Move QKV dimension to the front.
123
+ # B M (3 H d) -> 3 B M H d
124
+ B, M, _ = x.size()
125
+ x = x.view(B, M, 3, num_heads, -1)
126
+ return x.permute(2, 0, 1, 3, 4)
127
+
128
+ return CollectTokens.apply(x, _CONTEXT_PARALLEL_GROUP, num_heads)
129
+
130
+
131
+ class CollectHeads(torch.autograd.Function):
132
+ @staticmethod
133
+ def forward(ctx, x: torch.Tensor, group: dist.ProcessGroup):
134
+ """Redistribute tokens and receive heads.
135
+
136
+ Args:
137
+ x: Output of attention. Shape: [B, N, local_heads, head_dim]
138
+
139
+ Returns:
140
+ Shape: [B, M, num_heads * head_dim]
141
+ """
142
+ ctx.group = group
143
+ ctx.local_heads = x.size(2)
144
+ ctx.head_dim = x.size(3)
145
+ group_size = dist.get_world_size(group)
146
+ x = rearrange(x, "B (G M) h D -> G h M B D", G=group_size).contiguous()
147
+ output = torch.empty_like(x)
148
+ _all_to_all_single(output, x, group=group)
149
+ del x
150
+ return rearrange(output, "G h M B D -> B M (G h D)")
151
+
152
+
153
+ def all_to_all_collect_heads(x: torch.Tensor) -> torch.Tensor:
154
+ if not _CONTEXT_PARALLEL_GROUP:
155
+ # Merge heads.
156
+ return x.view(x.size(0), x.size(1), x.size(2) * x.size(3))
157
+
158
+ return CollectHeads.apply(x, _CONTEXT_PARALLEL_GROUP)
src/genmo/mochi_preview/dit/joint_model/layers.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections.abc
2
+ import math
3
+ from itertools import repeat
4
+ from typing import Callable, Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+
11
+
12
+ # From PyTorch internals
13
+ def _ntuple(n):
14
+ def parse(x):
15
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
16
+ return tuple(x)
17
+ return tuple(repeat(x, n))
18
+
19
+ return parse
20
+
21
+
22
+ to_2tuple = _ntuple(2)
23
+
24
+
25
+ class TimestepEmbedder(nn.Module):
26
+ def __init__(
27
+ self,
28
+ hidden_size: int,
29
+ frequency_embedding_size: int = 256,
30
+ *,
31
+ bias: bool = True,
32
+ timestep_scale: Optional[float] = None,
33
+ device: Optional[torch.device] = None,
34
+ ):
35
+ super().__init__()
36
+ self.mlp = nn.Sequential(
37
+ nn.Linear(frequency_embedding_size, hidden_size, bias=bias, device=device),
38
+ nn.SiLU(),
39
+ nn.Linear(hidden_size, hidden_size, bias=bias, device=device),
40
+ )
41
+ self.frequency_embedding_size = frequency_embedding_size
42
+ self.timestep_scale = timestep_scale
43
+
44
+ @staticmethod
45
+ def timestep_embedding(t, dim, max_period=10000):
46
+ half = dim // 2
47
+ freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
48
+ freqs.mul_(-math.log(max_period) / half).exp_()
49
+ args = t[:, None].float() * freqs[None]
50
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
51
+ if dim % 2:
52
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
53
+ return embedding
54
+
55
+ def forward(self, t):
56
+ if self.timestep_scale is not None:
57
+ t = t * self.timestep_scale
58
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
59
+ t_emb = self.mlp(t_freq)
60
+ return t_emb
61
+
62
+
63
+ class PooledCaptionEmbedder(nn.Module):
64
+ def __init__(
65
+ self,
66
+ caption_feature_dim: int,
67
+ hidden_size: int,
68
+ *,
69
+ bias: bool = True,
70
+ device: Optional[torch.device] = None,
71
+ ):
72
+ super().__init__()
73
+ self.caption_feature_dim = caption_feature_dim
74
+ self.hidden_size = hidden_size
75
+ self.mlp = nn.Sequential(
76
+ nn.Linear(caption_feature_dim, hidden_size, bias=bias, device=device),
77
+ nn.SiLU(),
78
+ nn.Linear(hidden_size, hidden_size, bias=bias, device=device),
79
+ )
80
+
81
+ def forward(self, x):
82
+ return self.mlp(x)
83
+
84
+
85
+ class FeedForward(nn.Module):
86
+ def __init__(
87
+ self,
88
+ in_features: int,
89
+ hidden_size: int,
90
+ multiple_of: int,
91
+ ffn_dim_multiplier: Optional[float],
92
+ device: Optional[torch.device] = None,
93
+ ):
94
+ super().__init__()
95
+ # keep parameter count and computation constant compared to standard FFN
96
+ hidden_size = int(2 * hidden_size / 3)
97
+ # custom dim factor multiplier
98
+ if ffn_dim_multiplier is not None:
99
+ hidden_size = int(ffn_dim_multiplier * hidden_size)
100
+ hidden_size = multiple_of * ((hidden_size + multiple_of - 1) // multiple_of)
101
+
102
+ self.hidden_dim = hidden_size
103
+ self.w1 = nn.Linear(in_features, 2 * hidden_size, bias=False, device=device)
104
+ self.w2 = nn.Linear(hidden_size, in_features, bias=False, device=device)
105
+
106
+ def forward(self, x):
107
+ # assert self.w1.weight.dtype == torch.bfloat16, f"FFN weight dtype {self.w1.weight.dtype} != bfloat16"
108
+ x, gate = self.w1(x).chunk(2, dim=-1)
109
+ x = self.w2(F.silu(x) * gate)
110
+ return x
111
+
112
+
113
+ class PatchEmbed(nn.Module):
114
+ def __init__(
115
+ self,
116
+ patch_size: int = 16,
117
+ in_chans: int = 3,
118
+ embed_dim: int = 768,
119
+ norm_layer: Optional[Callable] = None,
120
+ flatten: bool = True,
121
+ bias: bool = True,
122
+ dynamic_img_pad: bool = False,
123
+ device: Optional[torch.device] = None,
124
+ ):
125
+ super().__init__()
126
+ self.patch_size = to_2tuple(patch_size)
127
+ self.flatten = flatten
128
+ self.dynamic_img_pad = dynamic_img_pad
129
+
130
+ self.proj = nn.Conv2d(
131
+ in_chans,
132
+ embed_dim,
133
+ kernel_size=patch_size,
134
+ stride=patch_size,
135
+ bias=bias,
136
+ device=device,
137
+ )
138
+ assert norm_layer is None
139
+ self.norm = norm_layer(embed_dim, device=device) if norm_layer else nn.Identity()
140
+
141
+ def forward(self, x):
142
+ B, _C, T, H, W = x.shape
143
+ if not self.dynamic_img_pad:
144
+ assert (
145
+ H % self.patch_size[0] == 0
146
+ ), f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
147
+ assert (
148
+ W % self.patch_size[1] == 0
149
+ ), f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
150
+ else:
151
+ pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
152
+ pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
153
+ x = F.pad(x, (0, pad_w, 0, pad_h))
154
+
155
+ x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
156
+ x = self.proj(x)
157
+
158
+ # Flatten temporal and spatial dimensions.
159
+ if not self.flatten:
160
+ raise NotImplementedError("Must flatten output.")
161
+ x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T)
162
+
163
+ x = self.norm(x)
164
+ return x
165
+
166
+
167
+ class RMSNorm(torch.nn.Module):
168
+ def __init__(self, hidden_size, eps=1e-5, device=None):
169
+ super().__init__()
170
+ self.eps = eps
171
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device))
172
+ self.register_parameter("bias", None)
173
+
174
+ def forward(self, x):
175
+ # assert self.weight.dtype == torch.float32, f"RMSNorm weight dtype {self.weight.dtype} != float32"
176
+
177
+ x_fp32 = x.float()
178
+ x_normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
179
+ return (x_normed * self.weight).type_as(x)
src/genmo/mochi_preview/dit/joint_model/lora.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ import math
3
+ from typing import Dict, List, Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+
10
+ class LoRALayer:
11
+ def __init__(
12
+ self,
13
+ r: int,
14
+ lora_alpha: int,
15
+ lora_dropout: float,
16
+ merge_weights: bool,
17
+ ):
18
+ self.r = r
19
+ self.lora_alpha = lora_alpha
20
+ if lora_dropout > 0.0:
21
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
22
+ else:
23
+ self.lora_dropout = lambda x: x
24
+ self.merged = False
25
+ self.merge_weights = merge_weights
26
+
27
+
28
+ def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
29
+ assert bias == "none", f"Only bias='none' is supported"
30
+ for n, p in model.named_parameters():
31
+ if "lora_" not in n:
32
+ p.requires_grad = False
33
+
34
+
35
+ def lora_state_dict(model: nn.Module, bias: str = "none") -> Dict[str, torch.Tensor]:
36
+ assert bias == "none", f"Only bias='none' is supported"
37
+ my_state_dict = model.state_dict()
38
+ return {k: my_state_dict[k] for k in my_state_dict if "lora_" in k}
39
+
40
+
41
+ class LoraLinear(nn.Linear, LoRALayer):
42
+ # LoRA implemented in a dense layer
43
+ def __init__(
44
+ self,
45
+ in_features: int,
46
+ out_features: int,
47
+ r: int = 0,
48
+ lora_alpha: int = 1,
49
+ lora_dropout: float = 0.0,
50
+ fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
51
+ merge_weights: bool = True,
52
+ **kwargs,
53
+ ):
54
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
55
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
56
+
57
+ self.fan_in_fan_out = fan_in_fan_out
58
+ # Actual trainable parameters
59
+ if r > 0:
60
+ self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)).to(torch.float32))
61
+ self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)).to(torch.float32))
62
+ self.scaling = self.lora_alpha / self.r
63
+
64
+ # Freezing the pre-trained weight matrix
65
+ self.weight.requires_grad = False
66
+
67
+ self.reset_parameters()
68
+
69
+ if fan_in_fan_out:
70
+ self.weight.data = self.weight.data.transpose(0, 1)
71
+
72
+ def reset_parameters(self):
73
+ nn.Linear.reset_parameters(self)
74
+ if hasattr(self, "lora_A"):
75
+ # initialize B the same way as the default for nn.Linear and A to zero
76
+ # this is different than what is described in the paper but should not affect performance
77
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
78
+ nn.init.zeros_(self.lora_B)
79
+
80
+ def train(self, mode: bool = True):
81
+ def T(w):
82
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
83
+
84
+ nn.Linear.train(self, mode)
85
+ if mode:
86
+ if self.merge_weights and self.merged:
87
+ # Make sure that the weights are not merged
88
+ if self.r > 0:
89
+ self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
90
+ self.merged = False
91
+ else:
92
+ if self.merge_weights and not self.merged:
93
+ # Merge the weights and mark it
94
+ if self.r > 0:
95
+ self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
96
+ self.merged = True
97
+
98
+ def forward(self, x: torch.Tensor):
99
+ def T(w):
100
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
101
+
102
+ if self.r > 0 and not self.merged:
103
+ result = F.linear(x, T(self.weight), bias=self.bias)
104
+
105
+ x = self.lora_dropout(x)
106
+ x = x @ self.lora_A.transpose(0, 1)
107
+ x = x @ self.lora_B.transpose(0, 1)
108
+ x = x * self.scaling
109
+
110
+ return result + x
111
+ else:
112
+ return F.linear(x, T(self.weight), bias=self.bias)
src/genmo/mochi_preview/dit/joint_model/mod_rmsnorm.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def modulated_rmsnorm(x, scale, eps=1e-6):
5
+ dtype = x.dtype
6
+ x = x.float()
7
+
8
+ # Compute RMS
9
+ mean_square = x.pow(2).mean(-1, keepdim=True)
10
+ inv_rms = torch.rsqrt(mean_square + eps)
11
+
12
+ # Normalize and modulate
13
+ x_normed = x * inv_rms
14
+ x_modulated = x_normed * (1 + scale.unsqueeze(1).float())
15
+ return x_modulated.to(dtype)
src/genmo/mochi_preview/dit/joint_model/residual_tanh_gated_rmsnorm.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
5
+ # Convert to fp32 for precision
6
+ x_res = x_res.float()
7
+
8
+ # Compute RMS
9
+ mean_square = x_res.pow(2).mean(-1, keepdim=True)
10
+ scale = torch.rsqrt(mean_square + eps)
11
+
12
+ # Apply tanh to gate
13
+ tanh_gate = torch.tanh(gate).unsqueeze(1)
14
+
15
+ # Normalize and apply gated scaling
16
+ x_normed = x_res * scale * tanh_gate
17
+
18
+ # Apply residual connection
19
+ output = x + x_normed.type_as(x)
20
+ return output
src/genmo/mochi_preview/dit/joint_model/rope_mixed.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import math
3
+
4
+ import torch
5
+
6
+
7
+ def centers(start: float, stop, num, dtype=None, device=None):
8
+ """linspace through bin centers.
9
+
10
+ Args:
11
+ start (float): Start of the range.
12
+ stop (float): End of the range.
13
+ num (int): Number of points.
14
+ dtype (torch.dtype): Data type of the points.
15
+ device (torch.device): Device of the points.
16
+
17
+ Returns:
18
+ centers (Tensor): Centers of the bins. Shape: (num,).
19
+ """
20
+ edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
21
+ return (edges[:-1] + edges[1:]) / 2
22
+
23
+
24
+ @functools.lru_cache(maxsize=1)
25
+ def create_position_matrix(
26
+ T: int,
27
+ pH: int,
28
+ pW: int,
29
+ device: torch.device,
30
+ dtype: torch.dtype,
31
+ *,
32
+ target_area: float = 36864,
33
+ ):
34
+ """
35
+ Args:
36
+ T: int - Temporal dimension
37
+ pH: int - Height dimension after patchify
38
+ pW: int - Width dimension after patchify
39
+
40
+ Returns:
41
+ pos: [T * pH * pW, 3] - position matrix
42
+ """
43
+ with torch.no_grad():
44
+ # Create 1D tensors for each dimension
45
+ t = torch.arange(T, dtype=dtype)
46
+
47
+ # Positionally interpolate to area 36864.
48
+ # (3072x3072 frame with 16x16 patches = 192x192 latents).
49
+ # This automatically scales rope positions when the resolution changes.
50
+ # We use a large target area so the model is more sensitive
51
+ # to changes in the learned pos_frequencies matrix.
52
+ scale = math.sqrt(target_area / (pW * pH))
53
+ w = centers(-pW * scale / 2, pW * scale / 2, pW)
54
+ h = centers(-pH * scale / 2, pH * scale / 2, pH)
55
+
56
+ # Use meshgrid to create 3D grids
57
+ grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
58
+
59
+ # Stack and reshape the grids.
60
+ pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3]
61
+ pos = pos.view(-1, 3) # [T * pH * pW, 3]
62
+ pos = pos.to(dtype=dtype, device=device)
63
+
64
+ return pos
65
+
66
+
67
+ def compute_mixed_rotation(
68
+ freqs: torch.Tensor,
69
+ pos: torch.Tensor,
70
+ ):
71
+ """
72
+ Project each 3-dim position into per-head, per-head-dim 1D frequencies.
73
+
74
+ Args:
75
+ freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position
76
+ pos: [N, 3] - position of each token
77
+ num_heads: int
78
+
79
+ Returns:
80
+ freqs_cos: [N, num_heads, num_freqs] - cosine components
81
+ freqs_sin: [N, num_heads, num_freqs] - sine components
82
+ """
83
+ with torch.autocast("cuda", enabled=False):
84
+ assert freqs.ndim == 3
85
+ freqs_sum = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs)
86
+ freqs_cos = torch.cos(freqs_sum)
87
+ freqs_sin = torch.sin(freqs_sum)
88
+ return freqs_cos, freqs_sin
src/genmo/mochi_preview/dit/joint_model/temporal_rope.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Based on Llama3 Implementation.
2
+ import torch
3
+
4
+
5
+ def apply_rotary_emb_qk_real(
6
+ xqk: torch.Tensor,
7
+ freqs_cos: torch.Tensor,
8
+ freqs_sin: torch.Tensor,
9
+ ) -> torch.Tensor:
10
+ """
11
+ Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers.
12
+
13
+ Args:
14
+ xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D)
15
+ Can be either just query or just key, or both stacked along some batch or * dim.
16
+ freqs_cos (torch.Tensor): Precomputed cosine frequency tensor.
17
+ freqs_sin (torch.Tensor): Precomputed sine frequency tensor.
18
+
19
+ Returns:
20
+ torch.Tensor: The input tensor with rotary embeddings applied.
21
+ """
22
+ assert xqk.dtype == torch.bfloat16
23
+ # Split the last dimension into even and odd parts
24
+ xqk_even = xqk[..., 0::2]
25
+ xqk_odd = xqk[..., 1::2]
26
+
27
+ # Apply rotation
28
+ cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk)
29
+ sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk)
30
+
31
+ # Interleave the results back into the original shape
32
+ out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
33
+ assert out.dtype == torch.bfloat16
34
+ return out
src/genmo/mochi_preview/dit/joint_model/utils.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def modulate(x, shift, scale):
9
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
10
+
11
+
12
+ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
13
+ """
14
+ Pool tokens in x using mask.
15
+
16
+ NOTE: We assume x does not require gradients.
17
+
18
+ Args:
19
+ x: (B, L, D) tensor of tokens.
20
+ mask: (B, L) boolean tensor indicating which tokens are not padding.
21
+
22
+ Returns:
23
+ pooled: (B, D) tensor of pooled tokens.
24
+ """
25
+ assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens.
26
+ assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens.
27
+ mask = mask[:, :, None].to(dtype=x.dtype)
28
+ mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
29
+ pooled = (x * mask).sum(dim=1, keepdim=keepdim)
30
+ return pooled
31
+
32
+
33
+ class AttentionPool(nn.Module):
34
+ def __init__(
35
+ self,
36
+ embed_dim: int,
37
+ num_heads: int,
38
+ output_dim: int = None,
39
+ device: Optional[torch.device] = None,
40
+ ):
41
+ """
42
+ Args:
43
+ spatial_dim (int): Number of tokens in sequence length.
44
+ embed_dim (int): Dimensionality of input tokens.
45
+ num_heads (int): Number of attention heads.
46
+ output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
47
+ """
48
+ super().__init__()
49
+ self.num_heads = num_heads
50
+ self.to_kv = nn.Linear(embed_dim, 2 * embed_dim, device=device)
51
+ self.to_q = nn.Linear(embed_dim, embed_dim, device=device)
52
+ self.to_out = nn.Linear(embed_dim, output_dim or embed_dim, device=device)
53
+
54
+ def forward(self, x, mask):
55
+ """
56
+ Args:
57
+ x (torch.Tensor): (B, L, D) tensor of input tokens.
58
+ mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.
59
+
60
+ NOTE: We assume x does not require gradients.
61
+
62
+ Returns:
63
+ x (torch.Tensor): (B, D) tensor of pooled tokens.
64
+ """
65
+ D = x.size(2)
66
+
67
+ # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
68
+ attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
69
+ attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
70
+
71
+ # Average non-padding token features. These will be used as the query.
72
+ x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D)
73
+
74
+ # Concat pooled features to input sequence.
75
+ x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
76
+
77
+ # Compute queries, keys, values. Only the mean token is used to create a query.
78
+ kv = self.to_kv(x) # (B, L+1, 2 * D)
79
+ q = self.to_q(x[:, 0]) # (B, D)
80
+
81
+ # Extract heads.
82
+ head_dim = D // self.num_heads
83
+ kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
84
+ kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
85
+ k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
86
+ q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim)
87
+ q = q.unsqueeze(2) # (B, H, 1, head_dim)
88
+
89
+ # Compute attention.
90
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) # (B, H, 1, head_dim)
91
+
92
+ # Concatenate heads and run output.
93
+ x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
94
+ x = self.to_out(x)
95
+ return x
96
+
97
+
98
+ def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
99
+ D = xy.size(1)
100
+
101
+ # Pad sequences to (B, N + L, dim).
102
+ assert indices.ndim == 1
103
+ indices = indices.unsqueeze(1).expand(-1, D) # (total,) -> (total, num_heads * head_dim)
104
+ output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype)
105
+ output = torch.scatter(output, 0, indices, xy)
106
+ xy = output.view(B, N + L, D)
107
+
108
+ # Split visual and text tokens along the sequence length.
109
+ return torch.tensor_split(xy, (N,), dim=1)