DeepBeepMeep committed on
Commit
59880dc
·
1 Parent(s): 172584e

beta version

LICENSE.txt CHANGED
@@ -1,201 +1,17 @@
1
- Apache License
2
- Version 2.0, January 2004
3
- http://www.apache.org/licenses/
4
 
5
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
 
 
 
6
 
7
- 1. Definitions.
 
 
8
 
9
- "License" shall mean the terms and conditions for use, reproduction,
10
- and distribution as defined by Sections 1 through 9 of this document.
11
 
12
- "Licensor" shall mean the copyright owner or entity authorized by
13
- the copyright owner that is granting the License.
14
 
15
- "Legal Entity" shall mean the union of the acting entity and all
16
- other entities that control, are controlled by, or are under common
17
- control with that entity. For the purposes of this definition,
18
- "control" means (i) the power, direct or indirect, to cause the
19
- direction or management of such entity, whether by contract or
20
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
- outstanding shares, or (iii) beneficial ownership of such entity.
22
-
23
- "You" (or "Your") shall mean an individual or Legal Entity
24
- exercising permissions granted by this License.
25
-
26
- "Source" form shall mean the preferred form for making modifications,
27
- including but not limited to software source code, documentation
28
- source, and configuration files.
29
-
30
- "Object" form shall mean any form resulting from mechanical
31
- transformation or translation of a Source form, including but
32
- not limited to compiled object code, generated documentation,
33
- and conversions to other media types.
34
-
35
- "Work" shall mean the work of authorship, whether in Source or
36
- Object form, made available under the License, as indicated by a
37
- copyright notice that is included in or attached to the work
38
- (an example is provided in the Appendix below).
39
-
40
- "Derivative Works" shall mean any work, whether in Source or Object
41
- form, that is based on (or derived from) the Work and for which the
42
- editorial revisions, annotations, elaborations, or other modifications
43
- represent, as a whole, an original work of authorship. For the purposes
44
- of this License, Derivative Works shall not include works that remain
45
- separable from, or merely link (or bind by name) to the interfaces of,
46
- the Work and Derivative Works thereof.
47
-
48
- "Contribution" shall mean any work of authorship, including
49
- the original version of the Work and any modifications or additions
50
- to that Work or Derivative Works thereof, that is intentionally
51
- submitted to Licensor for inclusion in the Work by the copyright owner
52
- or by an individual or Legal Entity authorized to submit on behalf of
53
- the copyright owner. For the purposes of this definition, "submitted"
54
- means any form of electronic, verbal, or written communication sent
55
- to the Licensor or its representatives, including but not limited to
56
- communication on electronic mailing lists, source code control systems,
57
- and issue tracking systems that are managed by, or on behalf of, the
58
- Licensor for the purpose of discussing and improving the Work, but
59
- excluding communication that is conspicuously marked or otherwise
60
- designated in writing by the copyright owner as "Not a Contribution."
61
-
62
- "Contributor" shall mean Licensor and any individual or Legal Entity
63
- on behalf of whom a Contribution has been received by Licensor and
64
- subsequently incorporated within the Work.
65
-
66
- 2. Grant of Copyright License. Subject to the terms and conditions of
67
- this License, each Contributor hereby grants to You a perpetual,
68
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
- copyright license to reproduce, prepare Derivative Works of,
70
- publicly display, publicly perform, sublicense, and distribute the
71
- Work and such Derivative Works in Source or Object form.
72
-
73
- 3. Grant of Patent License. Subject to the terms and conditions of
74
- this License, each Contributor hereby grants to You a perpetual,
75
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
- (except as stated in this section) patent license to make, have made,
77
- use, offer to sell, sell, import, and otherwise transfer the Work,
78
- where such license applies only to those patent claims licensable
79
- by such Contributor that are necessarily infringed by their
80
- Contribution(s) alone or by combination of their Contribution(s)
81
- with the Work to which such Contribution(s) was submitted. If You
82
- institute patent litigation against any entity (including a
83
- cross-claim or counterclaim in a lawsuit) alleging that the Work
84
- or a Contribution incorporated within the Work constitutes direct
85
- or contributory patent infringement, then any patent licenses
86
- granted to You under this License for that Work shall terminate
87
- as of the date such litigation is filed.
88
-
89
- 4. Redistribution. You may reproduce and distribute copies of the
90
- Work or Derivative Works thereof in any medium, with or without
91
- modifications, and in Source or Object form, provided that You
92
- meet the following conditions:
93
-
94
- (a) You must give any other recipients of the Work or
95
- Derivative Works a copy of this License; and
96
-
97
- (b) You must cause any modified files to carry prominent notices
98
- stating that You changed the files; and
99
-
100
- (c) You must retain, in the Source form of any Derivative Works
101
- that You distribute, all copyright, patent, trademark, and
102
- attribution notices from the Source form of the Work,
103
- excluding those notices that do not pertain to any part of
104
- the Derivative Works; and
105
-
106
- (d) If the Work includes a "NOTICE" text file as part of its
107
- distribution, then any Derivative Works that You distribute must
108
- include a readable copy of the attribution notices contained
109
- within such NOTICE file, excluding those notices that do not
110
- pertain to any part of the Derivative Works, in at least one
111
- of the following places: within a NOTICE text file distributed
112
- as part of the Derivative Works; within the Source form or
113
- documentation, if provided along with the Derivative Works; or,
114
- within a display generated by the Derivative Works, if and
115
- wherever such third-party notices normally appear. The contents
116
- of the NOTICE file are for informational purposes only and
117
- do not modify the License. You may add Your own attribution
118
- notices within Derivative Works that You distribute, alongside
119
- or as an addendum to the NOTICE text from the Work, provided
120
- that such additional attribution notices cannot be construed
121
- as modifying the License.
122
-
123
- You may add Your own copyright statement to Your modifications and
124
- may provide additional or different license terms and conditions
125
- for use, reproduction, or distribution of Your modifications, or
126
- for any such Derivative Works as a whole, provided Your use,
127
- reproduction, and distribution of the Work otherwise complies with
128
- the conditions stated in this License.
129
-
130
- 5. Submission of Contributions. Unless You explicitly state otherwise,
131
- any Contribution intentionally submitted for inclusion in the Work
132
- by You to the Licensor shall be under the terms and conditions of
133
- this License, without any additional terms or conditions.
134
- Notwithstanding the above, nothing herein shall supersede or modify
135
- the terms of any separate license agreement you may have executed
136
- with Licensor regarding such Contributions.
137
-
138
- 6. Trademarks. This License does not grant permission to use the trade
139
- names, trademarks, service marks, or product names of the Licensor,
140
- except as required for reasonable and customary use in describing the
141
- origin of the Work and reproducing the content of the NOTICE file.
142
-
143
- 7. Disclaimer of Warranty. Unless required by applicable law or
144
- agreed to in writing, Licensor provides the Work (and each
145
- Contributor provides its Contributions) on an "AS IS" BASIS,
146
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
- implied, including, without limitation, any warranties or conditions
148
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
- PARTICULAR PURPOSE. You are solely responsible for determining the
150
- appropriateness of using or redistributing the Work and assume any
151
- risks associated with Your exercise of permissions under this License.
152
-
153
- 8. Limitation of Liability. In no event and under no legal theory,
154
- whether in tort (including negligence), contract, or otherwise,
155
- unless required by applicable law (such as deliberate and grossly
156
- negligent acts) or agreed to in writing, shall any Contributor be
157
- liable to You for damages, including any direct, indirect, special,
158
- incidental, or consequential damages of any character arising as a
159
- result of this License or out of the use or inability to use the
160
- Work (including but not limited to damages for loss of goodwill,
161
- work stoppage, computer failure or malfunction, or any and all
162
- other commercial damages or losses), even if such Contributor
163
- has been advised of the possibility of such damages.
164
-
165
- 9. Accepting Warranty or Additional Liability. While redistributing
166
- the Work or Derivative Works thereof, You may choose to offer,
167
- and charge a fee for, acceptance of support, warranty, indemnity,
168
- or other liability obligations and/or rights consistent with this
169
- License. However, in accepting such obligations, You may act only
170
- on Your own behalf and on Your sole responsibility, not on behalf
171
- of any other Contributor, and only if You agree to indemnify,
172
- defend, and hold each Contributor harmless for any liability
173
- incurred by, or claims asserted against, such Contributor by reason
174
- of your accepting any such warranty or additional liability.
175
-
176
- END OF TERMS AND CONDITIONS
177
-
178
- APPENDIX: How to apply the Apache License to your work.
179
-
180
- To apply the Apache License to your work, attach the following
181
- boilerplate notice, with the fields enclosed by brackets "[]"
182
- replaced with your own identifying information. (Don't include
183
- the brackets!) The text should be enclosed in the appropriate
184
- comment syntax for the file format. We also recommend that a
185
- file or class name and description of purpose be included on the
186
- same "printed page" as the copyright notice for easier
187
- identification within third-party archives.
188
-
189
- Copyright [yyyy] [name of copyright owner]
190
-
191
- Licensed under the Apache License, Version 2.0 (the "License");
192
- you may not use this file except in compliance with the License.
193
- You may obtain a copy of the License at
194
-
195
- http://www.apache.org/licenses/LICENSE-2.0
196
-
197
- Unless required by applicable law or agreed to in writing, software
198
- distributed under the License is distributed on an "AS IS" BASIS,
199
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
- See the License for the specific language governing permissions and
201
- limitations under the License.
 
1
+ FREE for Non-Commercial USE
 
 
2
 
3
+ You are free to:
4
+ - Share — copy and redistribute the material in any medium or format
5
+ - Adapt — remix, transform, and build upon the material
6
+ The licensor cannot revoke these freedoms as long as you follow the license terms.
7
 
8
+ Under the following terms:
9
+ - Attribution — You must give appropriate credit, provide a link to the license, and indicate if changes were made. You may do so in any reasonable manner, but not in any way that suggests the licensor endorses you or your use.
10
+ - NonCommercial — You may not use the material for commercial purposes.
11
 
12
+ - No additional restrictions — You may not apply legal terms or technological measures that legally restrict others from doing anything the license permits.
13
+ Notices:
14
 
15
+ - You do not have to comply with the license for elements of the material in the public domain or where your use is permitted by an applicable exception or limitation.
 
16
 
17
+ No warranties are given. The license may not give you all of the permissions necessary for your intended use. For example, other rights such as publicity, privacy, or moral rights may limit how you use the material.
README.md CHANGED
@@ -27,378 +27,159 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
27
 
28
  ## 🔥 Latest News!!
29
 
 
30
  * Feb 25, 2025: 👋 We've released the inference code and weights of Wan2.1.
31
  * Feb 27, 2025: 👋 Wan2.1 has been integrated into [ComfyUI](https://comfyanonymous.github.io/ComfyUI_examples/wan/). Enjoy!
32
 
33
 
34
- ## 📑 Todo List
35
- - Wan2.1 Text-to-Video
36
- - [x] Multi-GPU Inference code of the 14B and 1.3B models
37
- - [x] Checkpoints of the 14B and 1.3B models
38
- - [x] Gradio demo
39
- - [x] ComfyUI integration
40
- - [ ] Diffusers integration
41
- - Wan2.1 Image-to-Video
42
- - [x] Multi-GPU Inference code of the 14B model
43
- - [x] Checkpoints of the 14B model
44
- - [x] Gradio demo
45
- - [X] ComfyUI integration
46
- - [ ] Diffusers integration
47
-
48
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- ## Quickstart
51
 
52
- #### Installation
53
- Clone the repo:
54
- ```
55
- git clone https://github.com/Wan-Video/Wan2.1.git
56
- cd Wan2.1
57
- ```
58
-
59
- Install dependencies:
60
- ```
61
- # Ensure torch >= 2.4.0
62
- pip install -r requirements.txt
63
- ```
64
-
65
-
66
- #### Model Download
67
-
68
- | Models | Download Link | Notes |
69
- | --------------|-------------------------------------------------------------------------------|-------------------------------|
70
- | T2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P
71
- | I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P
72
- | I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P
73
- | T2V-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P
74
-
75
- > 💡Note: The 1.3B model is capable of generating videos at 720P resolution. However, due to limited training at this resolution, the results are generally less stable compared to 480P. For optimal performance, we recommend using 480P resolution.
76
-
77
-
78
- Download models using huggingface-cli:
79
- ```
80
- pip install "huggingface_hub[cli]"
81
- huggingface-cli download Wan-AI/Wan2.1-T2V-14B --local-dir ./Wan2.1-T2V-14B
82
- ```
83
-
84
- Download models using modelscope-cli:
85
- ```
86
- pip install modelscope
87
- modelscope download Wan-AI/Wan2.1-T2V-14B --local_dir ./Wan2.1-T2V-14B
88
- ```
89
- #### Run Text-to-Video Generation
90
-
91
- This repository supports two Text-to-Video models (1.3B and 14B) and two resolutions (480P and 720P). The parameters and configurations for these models are as follows:
92
-
93
- <table>
94
- <thead>
95
- <tr>
96
- <th rowspan="2">Task</th>
97
- <th colspan="2">Resolution</th>
98
- <th rowspan="2">Model</th>
99
- </tr>
100
- <tr>
101
- <th>480P</th>
102
- <th>720P</th>
103
- </tr>
104
- </thead>
105
- <tbody>
106
- <tr>
107
- <td>t2v-14B</td>
108
- <td style="color: green;">✔️</td>
109
- <td style="color: green;">✔️</td>
110
- <td>Wan2.1-T2V-14B</td>
111
- </tr>
112
- <tr>
113
- <td>t2v-1.3B</td>
114
- <td style="color: green;">✔️</td>
115
- <td style="color: red;">❌</td>
116
- <td>Wan2.1-T2V-1.3B</td>
117
- </tr>
118
- </tbody>
119
- </table>
120
-
121
-
122
- ##### (1) Without Prompt Extention
123
-
124
- To facilitate implementation, we will start with a basic version of the inference process that skips the [prompt extension](#2-using-prompt-extention) step.
125
-
126
- - Single-GPU inference
127
-
128
- ```
129
- python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
130
- ```
131
-
132
- If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model True` and `--t5_cpu` options to reduce GPU memory usage. For example, on an RTX 4090 GPU:
133
-
134
- ```
135
- python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
136
- ```
137
 
138
- > 💡Note: If you are using the `T2V-1.3B` model, we recommend setting the parameter `--sample_guide_scale 6`. The `--sample_shift parameter` can be adjusted within the range of 8 to 12 based on the performance.
139
 
 
140
 
141
- - Multi-GPU inference using FSDP + xDiT USP
142
 
143
- ```
144
- pip install "xfuser>=0.4.1"
145
- torchrun --nproc_per_node=8 generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
146
- ```
147
 
148
 
149
- ##### (2) Using Prompt Extention
150
 
151
- Extending the prompts can effectively enrich the details in the generated videos, further enhancing the video quality. Therefore, we recommend enabling prompt extension. We provide the following two methods for prompt extension:
 
152
 
153
- - Use the Dashscope API for extension.
154
- - Apply for a `dashscope.api_key` in advance ([EN](https://www.alibabacloud.com/help/en/model-studio/getting-started/first-api-call-to-qwen) | [CN](https://help.aliyun.com/zh/model-studio/getting-started/first-api-call-to-qwen)).
155
- - Configure the environment variable `DASH_API_KEY` to specify the Dashscope API key. For users of Alibaba Cloud's international site, you also need to set the environment variable `DASH_API_URL` to 'https://dashscope-intl.aliyuncs.com/api/v1'. For more detailed instructions, please refer to the [dashscope document](https://www.alibabacloud.com/help/en/model-studio/developer-reference/use-qwen-by-calling-api?spm=a2c63.p38356.0.i1).
156
- - Use the `qwen-plus` model for text-to-video tasks and `qwen-vl-max` for image-to-video tasks.
157
- - You can modify the model used for extension with the parameter `--prompt_extend_model`. For example:
158
- ```
159
- DASH_API_KEY=your_key python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" --use_prompt_extend --prompt_extend_method 'dashscope' --prompt_extend_target_lang 'ch'
160
- ```
161
 
162
- - Using a local model for extension.
 
 
 
163
 
164
- - By default, the Qwen model on HuggingFace is used for this extension. Users can choose Qwen models or other models based on the available GPU memory size.
165
- - For text-to-video tasks, you can use models like `Qwen/Qwen2.5-14B-Instruct`, `Qwen/Qwen2.5-7B-Instruct` and `Qwen/Qwen2.5-3B-Instruct`.
166
- - For image-to-video tasks, you can use models like `Qwen/Qwen2.5-VL-7B-Instruct` and `Qwen/Qwen2.5-VL-3B-Instruct`.
167
- - Larger models generally provide better extension results but require more GPU memory.
168
- - You can modify the model used for extension with the parameter `--prompt_extend_model` , allowing you to specify either a local model path or a Hugging Face model. For example:
169
 
170
- ```
171
- python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" --use_prompt_extend --prompt_extend_method 'local_qwen' --prompt_extend_target_lang 'ch'
172
- ```
173
 
174
- ##### (3) Runing local gradio
175
 
176
- ```
177
- cd gradio
178
- # if one uses dashscope’s API for prompt extension
179
- DASH_API_KEY=your_key python t2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir ./Wan2.1-T2V-14B
180
 
181
- # if one uses a local model for prompt extension
182
- python t2v_14B_singleGPU.py --prompt_extend_method 'local_qwen' --ckpt_dir ./Wan2.1-T2V-14B
183
- ```
184
 
 
 
 
 
185
 
186
- #### Run Image-to-Video Generation
187
-
188
- Similar to Text-to-Video, Image-to-Video is also divided into processes with and without the prompt extension step. The specific parameters and their corresponding settings are as follows:
189
- <table>
190
- <thead>
191
- <tr>
192
- <th rowspan="2">Task</th>
193
- <th colspan="2">Resolution</th>
194
- <th rowspan="2">Model</th>
195
- </tr>
196
- <tr>
197
- <th>480P</th>
198
- <th>720P</th>
199
- </tr>
200
- </thead>
201
- <tbody>
202
- <tr>
203
- <td>i2v-14B</td>
204
- <td style="color: green;">❌</td>
205
- <td style="color: green;">✔️</td>
206
- <td>Wan2.1-I2V-14B-720P</td>
207
- </tr>
208
- <tr>
209
- <td>i2v-14B</td>
210
- <td style="color: green;">✔️</td>
211
- <td style="color: red;">❌</td>
212
- <td>Wan2.1-T2V-14B-480P</td>
213
- </tr>
214
- </tbody>
215
- </table>
216
-
217
-
218
- ##### (1) Without Prompt Extention
219
-
220
- - Single-GPU inference
221
- ```
222
- python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
223
- ```
224
 
225
- > 💡For the Image-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
226
 
227
 
228
- - Multi-GPU inference using FSDP + xDiT USP
229
 
230
  ```
231
- pip install "xfuser>=0.4.1"
232
- torchrun --nproc_per_node=8 generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
233
- ```
234
-
235
- ##### (2) Using Prompt Extention
236
-
237
 
238
- The process of prompt extension can be referenced [here](#2-using-prompt-extention).
 
239
 
240
- Run with local prompt extention using `Qwen/Qwen2.5-VL-7B-Instruct`:
 
 
241
  ```
242
- python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --use_prompt_extend --prompt_extend_model Qwen/Qwen2.5-VL-7B-Instruct --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
243
  ```
244
 
245
- Run with remote prompt extention using `dashscope`:
246
  ```
247
- DASH_API_KEY=your_key python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --use_prompt_extend --prompt_extend_method 'dashscope' --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
 
248
  ```
249
 
250
- ##### (3) Runing local gradio
251
 
 
 
 
252
  ```
253
- cd gradio
254
- # if one only uses 480P model in gradio
255
- DASH_API_KEY=your_key python i2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_480p ./Wan2.1-I2V-14B-480P
256
 
257
- # if one only uses 720P model in gradio
258
- DASH_API_KEY=your_key python i2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_720p ./Wan2.1-I2V-14B-720P
259
 
260
- # if one uses both 480P and 720P models in gradio
261
- DASH_API_KEY=your_key python i2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_480p ./Wan2.1-I2V-14B-480P --ckpt_dir_720p ./Wan2.1-I2V-14B-720P
262
- ```
263
-
264
-
265
- #### Run Text-to-Image Generation
266
 
267
- Wan2.1 is a unified model for both image and video generation. Since it was trained on both types of data, it can also generate images. The command for generating images is similar to video generation, as follows:
268
 
269
- ##### (1) Without Prompt Extention
270
 
271
- - Single-GPU inference
272
- ```
273
- python generate.py --task t2i-14B --size 1024*1024 --ckpt_dir ./Wan2.1-T2V-14B --prompt '一个朴素端庄的美人'
274
- ```
275
 
276
- - Multi-GPU inference using FSDP + xDiT USP
277
 
278
- ```
279
- torchrun --nproc_per_node=8 generate.py --dit_fsdp --t5_fsdp --ulysses_size 8 --base_seed 0 --frame_num 1 --task t2i-14B --size 1024*1024 --prompt '一个朴素端庄的美人' --ckpt_dir ./Wan2.1-T2V-14B
 
280
  ```
281
 
282
- ##### (2) With Prompt Extention
283
 
284
- - Single-GPU inference
285
- ```
286
- python generate.py --task t2i-14B --size 1024*1024 --ckpt_dir ./Wan2.1-T2V-14B --prompt '一个朴素端庄的美人' --use_prompt_extend
287
- ```
288
 
289
- - Multi-GPU inference using FSDP + xDiT USP
290
- ```
291
- torchrun --nproc_per_node=8 generate.py --dit_fsdp --t5_fsdp --ulysses_size 8 --base_seed 0 --frame_num 1 --task t2i-14B --size 1024*1024 --ckpt_dir ./Wan2.1-T2V-14B --prompt '一个朴素端庄的美人' --use_prompt_extend
292
- ```
293
 
 
 
 
 
 
 
 
 
 
 
 
294
 
295
- ## Manual Evaluation
 
 
 
 
 
 
296
 
297
- ##### (1) Text-to-Video Evaluation
 
298
 
299
- Through manual evaluation, the results generated after prompt extension are superior to those from both closed-source and open-source models.
300
 
301
- <div align="center">
302
- <img src="assets/t2v_res.jpg" alt="" style="width: 80%;" />
303
- </div>
304
-
305
-
306
- ##### (2) Image-to-Video Evaluation
307
-
308
- We also conducted extensive manual evaluations to evaluate the performance of the Image-to-Video model, and the results are presented in the table below. The results clearly indicate that **Wan2.1** outperforms both closed-source and open-source models.
309
-
310
- <div align="center">
311
- <img src="assets/i2v_res.png" alt="" style="width: 80%;" />
312
- </div>
313
-
314
-
315
- ## Computational Efficiency on Different GPUs
316
-
317
- We test the computational efficiency of different **Wan2.1** models on different GPUs in the following table. The results are presented in the format: **Total time (s) / peak GPU memory (GB)**.
318
-
319
-
320
- <div align="center">
321
- <img src="assets/comp_effic.png" alt="" style="width: 80%;" />
322
- </div>
323
-
324
- > The parameter settings for the tests presented in this table are as follows:
325
- > (1) For the 1.3B model on 8 GPUs, set `--ring_size 8` and `--ulysses_size 1`;
326
- > (2) For the 14B model on 1 GPU, use `--offload_model True`;
327
- > (3) For the 1.3B model on a single 4090 GPU, set `--offload_model True --t5_cpu`;
328
- > (4) For all testings, no prompt extension was applied, meaning `--use_prompt_extend` was not enabled.
329
-
330
- > 💡Note: T2V-14B is slower than I2V-14B because the former samples 50 steps while the latter uses 40 steps.
331
-
332
-
333
- ## Community Contributions
334
- - [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) provides more support for **Wan2.1**, including video-to-video, FP8 quantization, VRAM optimization, LoRA training, and more. Please refer to [their examples](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo).
335
-
336
- -------
337
-
338
- ## Introduction of Wan2.1
339
-
340
- **Wan2.1** is designed on the mainstream diffusion transformer paradigm, achieving significant advancements in generative capabilities through a series of innovations. These include our novel spatio-temporal variational autoencoder (VAE), scalable training strategies, large-scale data construction, and automated evaluation metrics. Collectively, these contributions enhance the model’s performance and versatility.
341
-
342
-
343
- ##### (1) 3D Variational Autoencoders
344
- We propose a novel 3D causal VAE architecture, termed **Wan-VAE** specifically designed for video generation. By combining multiple strategies, we improve spatio-temporal compression, reduce memory usage, and ensure temporal causality. **Wan-VAE** demonstrates significant advantages in performance efficiency compared to other open-source VAEs. Furthermore, our **Wan-VAE** can encode and decode unlimited-length 1080P videos without losing historical temporal information, making it particularly well-suited for video generation tasks.
345
-
346
-
347
- <div align="center">
348
- <img src="assets/video_vae_res.jpg" alt="" style="width: 80%;" />
349
- </div>
350
-
351
-
352
- ##### (2) Video Diffusion DiT
353
-
354
- **Wan2.1** is designed using the Flow Matching framework within the paradigm of mainstream Diffusion Transformers. Our model's architecture uses the T5 Encoder to encode multilingual text input, with cross-attention in each transformer block embedding the text into the model structure. Additionally, we employ an MLP with a Linear layer and a SiLU layer to process the input time embeddings and predict six modulation parameters individually. This MLP is shared across all transformer blocks, with each block learning a distinct set of biases. Our experimental findings reveal a significant performance improvement with this approach at the same parameter scale.
355
-
356
- <div align="center">
357
- <img src="assets/video_dit_arch.jpg" alt="" style="width: 80%;" />
358
- </div>
359
-
360
-
361
- | Model | Dimension | Input Dimension | Output Dimension | Feedforward Dimension | Frequency Dimension | Number of Heads | Number of Layers |
362
- |--------|-----------|-----------------|------------------|-----------------------|---------------------|-----------------|------------------|
363
- | 1.3B | 1536 | 16 | 16 | 8960 | 256 | 12 | 30 |
364
- | 14B | 5120 | 16 | 16 | 13824 | 256 | 40 | 40 |
365
-
366
-
367
-
368
- ##### Data
369
-
370
- We curated and deduplicated a candidate dataset comprising a vast amount of image and video data. During the data curation process, we designed a four-step data cleaning process, focusing on fundamental dimensions, visual quality and motion quality. Through the robust data processing pipeline, we can easily obtain high-quality, diverse, and large-scale training sets of images and videos.
371
-
372
- ![figure1](assets/data_for_diff_stage.jpg "figure1")
373
-
374
-
375
- ##### Comparisons to SOTA
376
- We compared **Wan2.1** with leading open-source and closed-source models to evaluate the performace. Using our carefully designed set of 1,035 internal prompts, we tested across 14 major dimensions and 26 sub-dimensions. We then compute the total score by performing a weighted calculation on the scores of each dimension, utilizing weights derived from human preferences in the matching process. The detailed results are shown in the table below. These results demonstrate our model's superior performance compared to both open-source and closed-source models.
377
-
378
- ![figure1](assets/vben_vs_sota.png "figure1")
379
-
380
-
381
- ## Citation
382
- If you find our work helpful, please cite us.
383
-
384
- ```
385
- @article{wan2.1,
386
- title = {Wan: Open and Advanced Large-Scale Video Generative Models},
387
- author = {Wan Team},
388
- journal = {},
389
- year = {2025}
390
- }
391
- ```
392
 
393
- ## License Agreement
394
- The models in this repository are licensed under the Apache 2.0 License. We claim no rights over the your generate contents, granting you the freedom to use them while ensuring that your usage complies with the provisions of this license. You are fully accountable for your use of the models, which must not involve sharing any content that violates applicable laws, causes harm to individuals or groups, disseminates personal information intended for harm, spreads misinformation, or targets vulnerable populations. For a complete list of restrictions and details regarding your rights, please refer to the full text of the [license](LICENSE.txt).
395
 
 
 
396
 
397
- ## Acknowledgements
 
398
 
399
- We would like to thank the contributors to the [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium), [Qwen](https://huggingface.co/Qwen), [umt5-xxl](https://huggingface.co/google/umt5-xxl), [diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) repositories, for their open research.
 
400
 
 
 
401
 
402
 
403
- ## Contact Us
404
- If you would like to leave a message to our research or product teams, feel free to join our [Discord](https://discord.gg/p5XbdQV7) or [WeChat groups](https://gw.alicdn.com/imgextra/i2/O1CN01tqjWFi1ByuyehkTSB_!!6000000000015-0-tps-611-1279.jpg)!
 
27
 
28
  ## 🔥 Latest News!!
29
 
30
+ * Mar 03, 2025: 👋 Wan2.1GP by DeepBeepMeep: out of this world version! Memory consumption reduced by a factor of 2, with the possibility to generate more than 10s of video at 720p.
31
  * Feb 25, 2025: 👋 We've released the inference code and weights of Wan2.1.
32
  * Feb 27, 2025: 👋 Wan2.1 has been integrated into [ComfyUI](https://comfyanonymous.github.io/ComfyUI_examples/wan/). Enjoy!
33
 
34
 
35
+ ## Features
36
+ *GPU Poor version by **DeepBeepMeep**. This great video generator can now run smoothly on any GPU.*
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
+ This version has the following improvements over the original Wan2.1 model:
39
+ - Greatly reduced RAM and VRAM requirements
40
+ - Much faster thanks to compilation and fast loading / unloading
41
+ - 5 profiles to run the model at a decent speed on a low-end consumer config (32 GB of RAM and 12 GB of VRAM) and at a very good speed on a high-end consumer config (48 GB of RAM and 24 GB of VRAM)
42
+ - Autodownloading of the needed model files
43
+ - Improved gradio interface with progress bar and more options
44
+ - Multiple prompts / multiple generations per prompt
45
+ - Support multiple pretrained Loras with 32 GB of RAM or less
46
+ - Switch easily between quantized / non quantized models
47
+ - Much simpler installation
51
+ This fork by DeepBeepMeep integrates the mmgp module into gradio_server.py.
52
 
53
+ It illustrates how fast, properly working CPU offloading can be added to an existing model by changing only a few lines of code in the core model.
54
 
55
+ For more information on how to use the mmgp module, please go to: https://github.com/deepbeepmeep/mmgp
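+
+ As a hint of what this looks like in practice, here is a minimal sketch based on the `offload.profile` call used in the gradio scripts of this commit (the checkpoint path, profile number and budgets below are just example values):
+
+ ```python
+ # Minimal sketch: wrap an already built Wan2.1 text-to-video pipeline with mmgp offloading.
+ # Assumes the 'wan' package from this repo and its checkpoints are available locally.
+ import wan
+ from wan.configs import WAN_CONFIGS
+ from mmgp import offload
+
+ cfg = WAN_CONFIGS['t2v-14B']
+ wan_t2v = wan.WanT2V(config=cfg, checkpoint_dir="./ckpts", device_id=0, rank=0,
+                      t5_fsdp=False, dit_fsdp=False, use_usp=False)
+
+ # Expose the large sub-models to mmgp under explicit names...
+ pipe = {
+     "transformer": wan_t2v.model,                # the video DiT
+     "text_encoder": wan_t2v.text_encoder.model,  # the T5 text encoder
+     "vae": wan_t2v.vae.model,                    # the Wan-VAE
+ }
+ # ...then let a single call manage the RAM / VRAM offloading (profile 4 = LowRAM_LowVRAM).
+ offload.profile(pipe, profile_no=4, budgets={"transformer": 100, "*": 3000}, verboseLevel=2)
+ ```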
56
 
57
+ You will find the Wan2GP repository here: https://github.com/deepbeepmeep/Wan2GP
58
+
 
 
59
 
60
 
61
+ ## Installation Guide for Linux and Windows
62
 
63
+ We provide an `environment.yml` file for setting up a Conda environment.
64
+ Conda's installation instructions are available [here](https://docs.anaconda.com/free/miniconda/index.html).
65
 
66
+ This app has been tested with Python 3.10 / PyTorch 2.6.0 / CUDA 12.4.\
 
 
 
 
 
 
 
67
 
68
+ ```shell
69
+ # 1 - conda. Prepare and activate a conda environment
70
+ conda env create -f environment.yml
71
+ conda activate Wan2
72
 
73
+ # OR
 
 
 
 
74
 
75
+ # 1 - venv. Alternatively create a python 3.10 venv and then do the following
76
+ pip install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu124
 
77
 
 
78
 
79
+ # 2. Install pip dependencies
80
+ python -m pip install -r requirements.txt
 
 
81
 
82
+ # 3.1 optional Sage attention support (30% faster, easy to install on Linux but much harder on Windows)
83
+ python -m pip install sageattention==1.0.6
 
84
 
85
+ # or for Sage Attention 2 (40% faster, sorry only manual compilation for the moment)
86
+ git clone https://github.com/thu-ml/SageAttention
87
+ cd SageAttention
88
+ pip install -e .
89
 
90
+ # 3.2 optional Flash attention support (easy to install on Linux but much harder on Windows)
91
+ python -m pip install flash-attn==2.7.2.post1
92
 
 
93
 
94
 
 
95
 
96
  ```
 
 
 
 
 
 
97
 
98
+ Note that *Flash attention* and *Sage attention* are quite complex to install on Windows but offer better memory management (and consequently longer videos) than the default *sdpa attention*.
99
+ Likewise, *PyTorch compilation* will work on Windows only if you manage to install Triton, which is quite a complex process (see below for links).
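+
+ If you are unsure which attention backends ended up usable after installation, you can query the helper that gradio_server.py itself imports. A small sketch (assuming the helper takes no arguments, which is not guaranteed):
+
+ ```python
+ # Sketch: list the attention implementations detected in the current environment.
+ # get_attention_modes is imported this way in gradio_server.py; the no-argument call is an assumption.
+ from wan.modules.attention import get_attention_modes
+
+ print(get_attention_modes())  # expected to report some of: sdpa, flash, sage, sage2
+ ```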
100
 
101
+ ### Ready-to-use Python wheels for Windows users
102
+ I provide here links to simplify the installation for Windows users with Python 3.10 / PyTorch 2.6.0 / CUDA 12.4. As I am not hosting these files, I cannot provide support for them nor guarantee that they do what they should do.
103
+ - Triton (needed for *pytorch compilation* and *Sage attention*)
104
  ```
105
+ pip install https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post9/triton-3.2.0-cp310-cp310-win_amd64.whl # triton for pytorch 2.6.0
106
  ```
107
 
108
+ - Sage attention
109
  ```
110
+ pip install https://github.com/deepbeepmeep/SageAttention/raw/refs/heads/main/releases/sageattention-2.1.0-cp310-cp310-win_amd64.whl # for pytorch 2.6.0 (experimental; if it does not work you will need to install and compile manually, see above)
111
+
112
  ```
113
 
114
+ ## Run the application
115
 
116
+ ### Run a Gradio Server on port 7860 (recommended)
117
+ ```bash
118
+ python gradio_server.py
119
  ```
 
 
 
120
 
 
 
121
 
122
+ ### Loras support
 
 
 
 
 
123
 
124
+ Ready to be used, but theoretical for now since no Loras for Wan have been released as of today.
125
 
126
+ Every Lora stored in the 'loras' subfolder will be automatically loaded. You will then be able to activate / deactivate any of them when running the application.
127
 
128
+ For each activated Lora, you may specify a *multiplier*, a single float that corresponds to its weight (default is 1.0). Alternatively, you may specify a comma-separated list of float multipliers that describes how this Lora's multiplier evolves over the denoising steps (see the sketch below). For instance, assuming 30 denoising steps and the multiplier *0.9,0.8,0.7*, the Lora multiplier will be respectively 0.9, 0.8 and 0.7 for the step ranges 0-9, 10-19 and 20-29.
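+
+ As an illustration only (a hypothetical helper, not this repo's actual code), the mapping from a schedule string to a per-step weight works like this:
+
+ ```python
+ # Illustrative sketch: map a comma-separated multiplier schedule onto denoising steps.
+ def lora_multiplier_for_step(schedule: str, step: int, num_steps: int) -> float:
+     phases = [float(x) for x in schedule.split(",")]  # "0.9,0.8,0.7" -> [0.9, 0.8, 0.7]
+     steps_per_phase = num_steps / len(phases)         # 30 steps / 3 phases -> 10 steps each
+     return phases[min(int(step // steps_per_phase), len(phases) - 1)]
+
+ assert lora_multiplier_for_step("0.9,0.8,0.7", 12, 30) == 0.8  # step 12 falls in the 10-19 range
+ ```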
 
 
 
129
 
130
+ You can edit, save or delete Lora presets (combinations of Loras with their corresponding multipliers) directly from the gradio interface. Each preset is a file with the ".lset" extension stored in the loras directory and can be shared with other users.
131
 
132
+ You can then pre-activate the Loras corresponding to a preset when launching the gradio server:
133
+ ```bash
134
+ python gradio_server.py --lora-preset mylorapreset.lset # where 'mylorapreset.lset' is a preset stored in the 'loras' folder
135
  ```
136
 
137
+ Please note that command line parameters *--lora-weight* and *--lora-multiplier* have been deprecated since they are redundant with presets.
138
 
139
+ You will find prebuilt Loras on https://civitai.com/, or you can build them with tools such as kohya or onetrainer.
 
 
 
140
 
 
 
 
 
141
 
142
+ ### Command line parameters for Gradio Server
143
+ --profile no : default (4) : number of the profile to use, between 1 and 5\
144
+ --quantize-transformer bool: (default True) : enable / disable on the fly transformer quantization\
145
+ --lora-dir path : Path of directory that contains Loras in diffusers / safetensor format\
146
+ --lora-preset preset : name of the preset file (without the extension) to preload\
147
+ --verbose level : default (1) : level of information between 0 and 2\
148
+ --server-port portno : default (7860) : Gradio port no\
149
+ --server-name name : default (0.0.0.0) : Gradio server name\
150
+ --open-browser : open automatically Browser when launching Gradio Server\
151
+ --compile : turn on pytorch compilation\
152
+ --attention mode : force the attention mode, one of sdpa, flash, sage, sage2\
153
 
154
+ ### Profiles (for power users only)
155
+ You can choose between 5 profiles; they will try to make the most of your hardware:
156
+ - HighRAM_HighVRAM (1): the fastest, well suited for an RTX 3090 / RTX 4090 but consumes much more VRAM; adapted for fast, shorter videos
157
+ - HighRAM_LowVRAM (2): a bit slower, better suited for an RTX 3070/3080/4070/4080 or for an RTX 3090 / RTX 4090 with large picture batches or long videos
158
+ - LowRAM_HighVRAM (3): adapted for RTX 3090 / RTX 4090 with limited RAM but at the cost of VRAM (shorter videos)
159
+ - LowRAM_LowVRAM (4): if you have little VRAM or want to generate longer videos
160
+ - VerylowRAM_LowVRAM (5): requires at least 24 GB of RAM and 10 GB of VRAM; if you don't have much hardware it won't be fast, but it may work
161
 
162
+ Profiles 2 (High RAM) and 4 (Low RAM) are the most recommended since they are versatile (support for long videos at a slight performance cost).\
163
+ However, a safe approach is to start from profile 5 (the default profile) and then move progressively to profile 4 and then to profile 2, as long as the app remains responsive and doesn't trigger any out-of-memory error.
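+
+ For reference, the number you pass to *--profile* simply selects one of the names above:
+
+ ```python
+ # The --profile numbers map to the profile names listed above.
+ PROFILES = {
+     1: "HighRAM_HighVRAM",
+     2: "HighRAM_LowVRAM",
+     3: "LowRAM_HighVRAM",
+     4: "LowRAM_LowVRAM",
+     5: "VerylowRAM_LowVRAM",
+ }
+ print(PROFILES[2])  # one of the two recommended profiles
+ ```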
164
 
165
+ ### Other Models for the GPU Poor
166
 
167
+ - HunyuanVideoGP: https://github.com/deepbeepmeep/HunyuanVideoGP :\
168
+ One of the best open source Text to Video generators.
169
 
170
+ - Hunyuan3D-2GP: https://github.com/deepbeepmeep/Hunyuan3D-2GP :\
171
+ A great image-to-3D and text-to-3D tool by the Tencent team. Thanks to mmgp it can run with less than 6 GB of VRAM.
172
 
173
+ - FluxFillGP: https://github.com/deepbeepmeep/FluxFillGP :\
174
+ One of the best inpainting / outpainting tools based on Flux that can run with less than 12 GB of VRAM.
175
 
176
+ - Cosmos1GP: https://github.com/deepbeepmeep/Cosmos1GP :\
177
+ This application includes two models: a text-to-world generator and an image/video-to-world generator (probably the best open source image-to-video generator).
178
 
179
+ - OminiControlGP: https://github.com/deepbeepmeep/OminiControlGP :\
180
+ A very powerful Flux-derived application that can be used to transfer an object of your choice into a prompted scene. With mmgp you can run it with only 6 GB of VRAM.
181
 
182
+ - YuE GP: https://github.com/deepbeepmeep/YuEGP :\
183
+ A great song generator (instruments + singer's voice) based on prompted Lyrics and a genre description. Thanks to mmgp you can run it with less than 10 GB of VRAM without waiting forever.
184
 
185
 
 
 
gradio/i2v_14B_singleGPU.py CHANGED
@@ -24,8 +24,9 @@ wan_i2v_720P = None
24
 
25
 
26
  # Button Func
27
- def load_model(value):
28
  global wan_i2v_480P, wan_i2v_720P
 
29
 
30
  if value == '------':
31
  print("No model loaded")
@@ -52,8 +53,11 @@ def load_model(value):
52
  t5_fsdp=False,
53
  dit_fsdp=False,
54
  use_usp=False,
55
- )
 
56
  print("done", flush=True)
 
 
57
  return '720P'
58
 
59
  if value == '480P':
@@ -77,11 +81,16 @@ def load_model(value):
77
  t5_fsdp=False,
78
  dit_fsdp=False,
79
  use_usp=False,
 
80
  )
81
  print("done", flush=True)
 
 
 
82
  return '480P'
83
 
84
 
 
85
  def prompt_enc(prompt, img, tar_lang):
86
  print('prompt extend...')
87
  if img is None:
@@ -96,10 +105,12 @@ def prompt_enc(prompt, img, tar_lang):
96
  return prompt_output.prompt
97
 
98
 
99
- def i2v_generation(img2vid_prompt, img2vid_image, resolution, sd_steps,
100
  guide_scale, shift_scale, seed, n_prompt):
101
  # print(f"{img2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
102
-
 
 
103
  if resolution == '------':
104
  print(
105
  'Please specify at least one resolution ckpt dir or specify the resolution'
@@ -118,19 +129,19 @@ def i2v_generation(img2vid_prompt, img2vid_image, resolution, sd_steps,
118
  guide_scale=guide_scale,
119
  n_prompt=n_prompt,
120
  seed=seed,
121
- offload_model=True)
122
  else:
123
  global wan_i2v_480P
124
  video = wan_i2v_480P.generate(
125
  img2vid_prompt,
126
  img2vid_image,
127
  max_area=MAX_AREA_CONFIGS['480*832'],
128
- shift=shift_scale,
129
  sampling_steps=sd_steps,
130
  guide_scale=guide_scale,
131
  n_prompt=n_prompt,
132
  seed=seed,
133
- offload_model=True)
134
 
135
  cache_video(
136
  tensor=video[None],
@@ -169,6 +180,7 @@ def gradio_interface():
169
  )
170
  img2vid_prompt = gr.Textbox(
171
  label="Prompt",
 
172
  placeholder="Describe the video you want to generate",
173
  )
174
  tar_lang = gr.Radio(
@@ -262,6 +274,8 @@ def _parse_args():
262
  help="The prompt extend model to use.")
263
 
264
  args = parser.parse_args()
 
 
265
  assert args.ckpt_dir_720p is not None or args.ckpt_dir_480p is not None, "Please specify at least one checkpoint directory."
266
 
267
  return args
@@ -269,6 +283,12 @@ def _parse_args():
269
 
270
  if __name__ == '__main__':
271
  args = _parse_args()
 
 
 
 
 
 
272
 
273
  print("Step1: Init prompt_expander...", end='', flush=True)
274
  if args.prompt_extend_method == "dashscope":
 
24
 
25
 
26
  # Button Func
27
+ def load_i2v_model(value):
28
  global wan_i2v_480P, wan_i2v_720P
29
+ from mmgp import offload
30
 
31
  if value == '------':
32
  print("No model loaded")
 
53
  t5_fsdp=False,
54
  dit_fsdp=False,
55
  use_usp=False,
56
+ i2v720p= True
57
+ )
58
  print("done", flush=True)
59
+ pipe = {"transformer": wan_i2v_720P.model, "text_encoder" : wan_i2v_720P.text_encoder.model, "text_encoder_2": wan_i2v_720P.clip.model, "vae": wan_i2v_720P.vae.model } #
60
+ offload.profile(pipe, profile_no=4, budgets = {"transformer":100, "*":3000}, verboseLevel=2, compile="transformer", quantizeTransformer = False, pinnedMemory = False)
61
  return '720P'
62
 
63
  if value == '480P':
 
81
  t5_fsdp=False,
82
  dit_fsdp=False,
83
  use_usp=False,
84
+ i2v720p= False
85
  )
86
  print("done", flush=True)
87
+ pipe = {"transformer": wan_i2v_480P.model, "text_encoder" : wan_i2v_480P.text_encoder.model, "text_encoder_2": wan_i2v_480P.clip.model, "vae": wan_i2v_480P.vae.model } #
88
+ offload.profile(pipe, profile_no=4, budgets = {"model":100, "*":3000}, verboseLevel=2, compile="transformer" )
89
+
90
  return '480P'
91
 
92
 
93
+
94
  def prompt_enc(prompt, img, tar_lang):
95
  print('prompt extend...')
96
  if img is None:
 
105
  return prompt_output.prompt
106
 
107
 
108
+ def i2v_generation(img2vid_prompt, img2vid_image, res, sd_steps,
109
  guide_scale, shift_scale, seed, n_prompt):
110
  # print(f"{img2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
111
+ global resolution
112
+ from PIL import Image
113
+ img2vid_image = Image.open("d:\mammoth2.jpg")
114
  if resolution == '------':
115
  print(
116
  'Please specify at least one resolution ckpt dir or specify the resolution'
 
129
  guide_scale=guide_scale,
130
  n_prompt=n_prompt,
131
  seed=seed,
132
+ offload_model=False)
133
  else:
134
  global wan_i2v_480P
135
  video = wan_i2v_480P.generate(
136
  img2vid_prompt,
137
  img2vid_image,
138
  max_area=MAX_AREA_CONFIGS['480*832'],
139
+ shift=3.0, #shift_scale
140
  sampling_steps=sd_steps,
141
  guide_scale=guide_scale,
142
  n_prompt=n_prompt,
143
  seed=seed,
144
+ offload_model=False)
145
 
146
  cache_video(
147
  tensor=video[None],
 
180
  )
181
  img2vid_prompt = gr.Textbox(
182
  label="Prompt",
183
+ value="Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.",
184
  placeholder="Describe the video you want to generate",
185
  )
186
  tar_lang = gr.Radio(
 
274
  help="The prompt extend model to use.")
275
 
276
  args = parser.parse_args()
277
+ args.ckpt_dir_720p = "../ckpts" # os.path.join("ckpt")
278
+ args.ckpt_dir_480p = "../ckpts" # os.path.join("ckpt")
279
  assert args.ckpt_dir_720p is not None or args.ckpt_dir_480p is not None, "Please specify at least one checkpoint directory."
280
 
281
  return args
 
283
 
284
  if __name__ == '__main__':
285
  args = _parse_args()
286
+ global resolution
287
+ # load_model('720P')
288
+ # resolution = '720P'
289
+ resolution = '480P'
290
+
291
+ load_model(resolution)
292
 
293
  print("Step1: Init prompt_expander...", end='', flush=True)
294
  if args.prompt_extend_method == "dashscope":
gradio/t2i_14B_singleGPU.py CHANGED
@@ -190,6 +190,7 @@ if __name__ == '__main__':
190
 
191
  print("Step2: Init 14B t2i model...", end='', flush=True)
192
  cfg = WAN_CONFIGS['t2i-14B']
 
193
  wan_t2i = wan.WanT2V(
194
  config=cfg,
195
  checkpoint_dir=args.ckpt_dir,
 
190
 
191
  print("Step2: Init 14B t2i model...", end='', flush=True)
192
  cfg = WAN_CONFIGS['t2i-14B']
193
+ # cfg = WAN_CONFIGS['t2v-1.3B']
194
  wan_t2i = wan.WanT2V(
195
  config=cfg,
196
  checkpoint_dir=args.ckpt_dir,
gradio/t2v_14B_singleGPU.py CHANGED
@@ -46,7 +46,7 @@ def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
46
  guide_scale=guide_scale,
47
  n_prompt=n_prompt,
48
  seed=seed,
49
- offload_model=True)
50
 
51
  cache_video(
52
  tensor=video[None],
@@ -177,28 +177,39 @@ if __name__ == '__main__':
177
  args = _parse_args()
178
 
179
  print("Step1: Init prompt_expander...", end='', flush=True)
180
- if args.prompt_extend_method == "dashscope":
181
- prompt_expander = DashScopePromptExpander(
182
- model_name=args.prompt_extend_model, is_vl=False)
183
- elif args.prompt_extend_method == "local_qwen":
184
- prompt_expander = QwenPromptExpander(
185
- model_name=args.prompt_extend_model, is_vl=False, device=0)
186
- else:
187
- raise NotImplementedError(
188
- f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
189
- print("done", flush=True)
 
 
 
190
 
191
  print("Step2: Init 14B t2v model...", end='', flush=True)
192
  cfg = WAN_CONFIGS['t2v-14B']
 
 
193
  wan_t2v = wan.WanT2V(
194
  config=cfg,
195
- checkpoint_dir=args.ckpt_dir,
196
  device_id=0,
197
  rank=0,
198
  t5_fsdp=False,
199
  dit_fsdp=False,
200
  use_usp=False,
201
  )
 
 
 
 
 
 
202
  print("done", flush=True)
203
 
204
  demo = gradio_interface()
 
46
  guide_scale=guide_scale,
47
  n_prompt=n_prompt,
48
  seed=seed,
49
+ offload_model=False)
50
 
51
  cache_video(
52
  tensor=video[None],
 
177
  args = _parse_args()
178
 
179
  print("Step1: Init prompt_expander...", end='', flush=True)
180
+ prompt_expander = None
181
+ # if args.prompt_extend_method == "dashscope":
182
+ # prompt_expander = DashScopePromptExpander(
183
+ # model_name=args.prompt_extend_model, is_vl=False)
184
+ # elif args.prompt_extend_method == "local_qwen":
185
+ # prompt_expander = QwenPromptExpander(
186
+ # model_name=args.prompt_extend_model, is_vl=False, device=0)
187
+ # else:
188
+ # raise NotImplementedError(
189
+ # f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
190
+ # print("done", flush=True)
191
+
192
+ from mmgp import offload
193
 
194
  print("Step2: Init 14B t2v model...", end='', flush=True)
195
  cfg = WAN_CONFIGS['t2v-14B']
196
+ # cfg = WAN_CONFIGS['t2v-1.3B']
197
+
198
  wan_t2v = wan.WanT2V(
199
  config=cfg,
200
+ checkpoint_dir="../ckpts",
201
  device_id=0,
202
  rank=0,
203
  t5_fsdp=False,
204
  dit_fsdp=False,
205
  use_usp=False,
206
  )
207
+
208
+ pipe = {"transformer": wan_t2v.model, "text_encoder" : wan_t2v.text_encoder.model, "vae": wan_t2v.vae.model } #
209
+ # offload.profile(pipe, profile_no=4, budgets = {"transformer":100, "*":3000}, verboseLevel=2, quantizeTransformer = False, compile = "transformer") #
210
+ offload.profile(pipe, profile_no=4, budgets = {"transformer":100, "*":3000}, verboseLevel=2, quantizeTransformer = False) #
211
+ # offload.profile(pipe, profile_no=4, budgets = {"transformer":3000, "*":3000}, verboseLevel=2, quantizeTransformer = False)
212
+
213
  print("done", flush=True)
214
 
215
  demo = gradio_interface()
gradio_server.py ADDED
@@ -0,0 +1,1275 @@
1
+ import os
2
+ import time
3
+ import argparse
4
+ from mmgp import offload, safetensors2, profile_type
5
+ try:
6
+ import triton
7
+ except ImportError:
8
+ pass
9
+ from pathlib import Path
10
+ from datetime import datetime
11
+ import gradio as gr
12
+ import random
13
+ import json
14
+ import wan
15
+ from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES
16
+ from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
17
+ from wan.utils.utils import cache_video
18
+ from wan.modules.attention import get_attention_modes
19
+ import torch
20
+ import gc
21
+
22
+ def _parse_args():
23
+ parser = argparse.ArgumentParser(
24
+ description="Generate a video from a text prompt or image using Gradio")
25
+ parser.add_argument(
26
+ "--ckpt_dir_720p",
27
+ type=str,
28
+ default=None,
29
+ help="The path to the checkpoint directory.")
30
+ parser.add_argument(
31
+ "--ckpt_dir_480p",
32
+ type=str,
33
+ default=None,
34
+ help="The path to the checkpoint directory.")
35
+ parser.add_argument(
36
+ "--prompt_extend_method",
37
+ type=str,
38
+ default="local_qwen",
39
+ choices=["dashscope", "local_qwen"],
40
+ help="The prompt extend method to use.")
41
+ parser.add_argument(
42
+ "--prompt_extend_model",
43
+ type=str,
44
+ default=None,
45
+ help="The prompt extend model to use.")
46
+
47
+ parser.add_argument(
48
+ "--quantize-transformer",
49
+ action="store_true",
50
+ help="On the fly 'transformer' quantization"
51
+ )
52
+
53
+
54
+ parser.add_argument(
55
+ "--lora-dir-i2v",
56
+ type=str,
57
+ default="loras_i2v",
58
+ help="Path to a directory that contains Loras for i2v"
59
+ )
60
+
61
+
62
+ parser.add_argument(
63
+ "--lora-dir",
64
+ type=str,
65
+ default="loras",
66
+ help="Path to a directory that contains Loras"
67
+ )
68
+
69
+
70
+ parser.add_argument(
71
+ "--lora-preset",
72
+ type=str,
73
+ default="",
74
+ help="Lora preset to preload"
75
+ )
76
+
77
+ parser.add_argument(
78
+ "--lora-preset-i2v",
79
+ type=str,
80
+ default="",
81
+ help="Lora preset to preload for i2v"
82
+ )
83
+
84
+ parser.add_argument(
85
+ "--profile",
86
+ type=str,
87
+ default=-1,
88
+ help="Profile No"
89
+ )
90
+
91
+ parser.add_argument(
92
+ "--verbose",
93
+ type=str,
94
+ default=1,
95
+ help="Verbose level"
96
+ )
97
+
98
+ parser.add_argument(
99
+ "--server-port",
100
+ type=str,
101
+ default=0,
102
+ help="Server port"
103
+ )
104
+
105
+ parser.add_argument(
106
+ "--server-name",
107
+ type=str,
108
+ default="",
109
+ help="Server name"
110
+ )
111
+
112
+ parser.add_argument(
113
+ "--open-browser",
114
+ action="store_true",
115
+ help="open browser"
116
+ )
117
+
118
+ parser.add_argument(
119
+ "--t2v",
120
+ action="store_true",
121
+ help="text to video mode"
122
+ )
123
+
124
+ parser.add_argument(
125
+ "--i2v",
126
+ action="store_true",
127
+ help="image to video mode"
128
+ )
129
+
130
+ parser.add_argument(
131
+ "--compile",
132
+ action="store_true",
133
+ help="Enable pytorch compilation"
134
+ )
135
+
136
+ # parser.add_argument(
137
+ # "--fast",
138
+ # action="store_true",
139
+ # help="use Fast model"
140
+ # )
141
+
142
+ # parser.add_argument(
143
+ # "--fastest",
144
+ # action="store_true",
145
+ # help="activate the best config"
146
+ # )
147
+
148
+ parser.add_argument(
149
+ "--attention",
150
+ type=str,
151
+ default="",
152
+ help="attention mode"
153
+ )
154
+
155
+ parser.add_argument(
156
+ "--vae-config",
157
+ type=str,
158
+ default="",
159
+ help="vae config mode"
160
+ )
161
+
162
+
163
+ args = parser.parse_args()
164
+ args.ckpt_dir_720p = "../ckpts" # os.path.join("ckpt")
165
+ args.ckpt_dir_480p = "../ckpts" # os.path.join("ckpt")
166
+ assert args.ckpt_dir_720p is not None or args.ckpt_dir_480p is not None, "Please specify at least one checkpoint directory."
167
+
168
+ return args
169
+
170
+ attention_modes_supported = get_attention_modes()
171
+
172
+ args = _parse_args()
173
+ args.flow_reverse = True
174
+
175
+
176
+ lock_ui_attention = False
177
+ lock_ui_transformer = False
178
+ lock_ui_compile = False
179
+
180
+
181
+ force_profile_no = int(args.profile)
182
+ verbose_level = int(args.verbose)
183
+ quantizeTransformer = args.quantize_transformer
184
+
185
+ transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors"]
186
+ transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors"]
187
+ text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
188
+
189
+ server_config_filename = "gradio_config.json"
190
+
191
+ if not Path(server_config_filename).is_file():
192
+ server_config = {"attention_mode" : "auto",
193
+ "transformer_filename": transformer_choices_t2v[0],
194
+ "transformer_filename_i2v": transformer_choices_i2v[1], ########
195
+ "text_encoder_filename" : text_encoder_choices[1],
196
+ "compile" : "",
197
+ "default_ui": "t2v",
198
+ "vae_config": 0,
199
+ "profile" : profile_type.LowRAM_LowVRAM }
200
+
201
+ with open(server_config_filename, "w", encoding="utf-8") as writer:
202
+ writer.write(json.dumps(server_config))
203
+ else:
204
+ with open(server_config_filename, "r", encoding="utf-8") as reader:
205
+ text = reader.read()
206
+ server_config = json.loads(text)
207
+
208
+
209
+ transformer_filename_t2v = server_config["transformer_filename"]
210
+ transformer_filename_i2v = server_config.get("transformer_filename_i2v", transformer_choices_i2v[1]) ########
211
+
212
+ text_encoder_filename = server_config["text_encoder_filename"]
213
+ attention_mode = server_config["attention_mode"]
214
+ if len(args.attention)> 0:
215
+ if args.attention in ["auto", "sdpa", "sage", "sage2", "flash", "xformers"]:
216
+ attention_mode = args.attention
217
+ lock_ui_attention = True
218
+ else:
219
+ raise Exception(f"Unknown attention mode '{args.attention}'")
220
+
221
+ profile = force_profile_no if force_profile_no >=0 else server_config["profile"]
222
+ compile = server_config.get("compile", "")
223
+ vae_config = server_config.get("vae_config", 0)
224
+ if len(args.vae_config) > 0:
225
+ vae_config = int(args.vae_config)
226
+
227
+ default_ui = server_config.get("default_ui", "t2v")
228
+ use_image2video = default_ui != "t2v"
229
+ if args.t2v:
230
+ use_image2video = False
231
+ if args.i2v:
232
+ use_image2video = True
233
+
234
+ if use_image2video:
235
+ lora_dir =args.lora_dir_i2v
236
+ lora_preselected_preset = args.lora_preset_i2v
237
+ else:
238
+ lora_dir =args.lora_dir
239
+ lora_preselected_preset = args.lora_preset
240
+
241
+ default_tea_cache = 0
242
+ # if args.fast : #or args.fastest
243
+ # transformer_filename_t2v = transformer_choices_t2v[2]
244
+ # attention_mode="sage2" if "sage2" in attention_modes_supported else "sage"
245
+ # default_tea_cache = 0.15
246
+ # lock_ui_attention = True
247
+ # lock_ui_transformer = True
248
+
249
+ if args.compile: #args.fastest or
250
+ compile="transformer"
251
+ lock_ui_compile = True
252
+
253
+
254
+ #attention_mode="sage"
255
+ #attention_mode="sage2"
256
+ #attention_mode="flash"
257
+ #attention_mode="sdpa"
258
+ #attention_mode="xformers"
259
+ # compile = "transformer"
260
+
261
+ def download_models(transformer_filename, text_encoder_filename):
262
+ def computeList(filename):
263
+ pos = filename.rfind("/")
264
+ filename = filename[pos+1:]
265
+ return [filename]
266
+
267
+ from huggingface_hub import hf_hub_download, snapshot_download
268
+ repoId = "DeepBeepMeep/Wan2.1"
269
+ sourceFolderList = ["xlm-roberta-large", "", ]
270
+ fileList = [ [], ["Wan2.1_VAE.pth", "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
271
+ targetRoot = "ckpts/"
272
+ for sourceFolder, files in zip(sourceFolderList,fileList ):
273
+ if len(files)==0:
274
+ if not Path(targetRoot + sourceFolder).exists():
275
+ snapshot_download(repo_id=repoId, allow_patterns=sourceFolder +"/*", local_dir= targetRoot)
276
+ else:
277
+ for onefile in files:
278
+ if len(sourceFolder) > 0:
279
+ if not os.path.isfile(targetRoot + sourceFolder + "/" + onefile ):
280
+ hf_hub_download(repo_id=repoId, filename=onefile, local_dir = targetRoot, subfolder=sourceFolder)
281
+ else:
282
+ if not os.path.isfile(targetRoot + onefile ):
283
+ hf_hub_download(repo_id=repoId, filename=onefile, local_dir = targetRoot)
284
+
285
+
286
+ offload.default_verboseLevel = verbose_level
287
+
288
+ download_models(transformer_filename_i2v if use_image2video else transformer_filename_t2v, text_encoder_filename)
289
+
290
+ def sanitize_file_name(file_name):
291
+ return file_name.replace("/","").replace("\\","").replace(":","").replace("|","").replace("?","").replace("<","").replace(">","").replace("\"","")
292
+
293
+ def extract_preset(lset_name, loras):
294
+ lset_name = sanitize_file_name(lset_name)
295
+ if not lset_name.endswith(".lset"):
296
+ lset_name_filename = os.path.join(lora_dir, lset_name + ".lset" )
297
+ else:
298
+ lset_name_filename = os.path.join(lora_dir, lset_name )
299
+
300
+ if not os.path.isfile(lset_name_filename):
301
+ raise gr.Error(f"Preset '{lset_name}' not found ")
302
+
303
+ with open(lset_name_filename, "r", encoding="utf-8") as reader:
304
+ text = reader.read()
305
+ lset = json.loads(text)
306
+
307
+ loras_choices_files = lset["loras"]
308
+ loras_choices = []
309
+ missing_loras = []
310
+ for lora_file in loras_choices_files:
311
+ loras_choice_no = loras.index(os.path.join(lora_dir, lora_file))
312
+ if loras_choice_no < 0:
313
+ missing_loras.append(lora_file)
314
+ else:
315
+ loras_choices.append(str(loras_choice_no))
316
+
317
+ if len(missing_loras) > 0:
318
+ raise gr.Error(f"Unable to apply Lora preset '{lset_name} because the following Loras files are missing: {missing_loras}")
319
+
320
+ loras_mult_choices = lset["loras_mult"]
321
+ return loras_choices, loras_mult_choices
322
+
323
+ def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_map = None):
324
+ loras =[]
325
+ loras_names = []
326
+ default_loras_choices = []
327
+ default_loras_multis_str = ""
328
+ loras_presets = []
329
+
330
+ from pathlib import Path
331
+
332
+ if lora_dir != None :
333
+ if not os.path.isdir(lora_dir):
334
+ raise Exception("--lora-dir should be a path to a directory that contains Loras")
335
+
336
+ default_lora_preset = ""
337
+
338
+ if lora_dir != None:
339
+ import glob
340
+ dir_loras = glob.glob( os.path.join(lora_dir , "*.sft") ) + glob.glob( os.path.join(lora_dir , "*.safetensors") )
341
+ dir_loras.sort()
342
+ loras += [element for element in dir_loras if element not in loras ]
343
+
344
+ dir_presets = glob.glob( os.path.join(lora_dir , "*.lset") )
345
+ dir_presets.sort()
346
+ loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets]
347
+
348
+ if len(loras) > 0:
349
+ loras_names = [ Path(lora).stem for lora in loras ]
350
+ offload.load_loras_into_model(pipe.transformer, loras, activate_all_loras=False, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
351
+
352
+ if len(lora_preselected_preset) > 0:
353
+ if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")):
354
+ raise Exception(f"Unknown preset '{lora_preselected_preset}'")
355
+ default_lora_preset = lora_preselected_preset
356
+ default_loras_choices, default_loras_multis_str= extract_preset(default_lora_preset, loras)
357
+
358
+ return loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets
359
+
360
+
361
+ def load_t2v_model(model_filename, value):
362
+
363
+ cfg = WAN_CONFIGS['t2v-14B']
364
+ # cfg = WAN_CONFIGS['t2v-1.3B']
365
+ print("load t2v model...")
366
+
367
+ wan_model = wan.WanT2V(
368
+ config=cfg,
369
+ checkpoint_dir="ckpts",
370
+ device_id=0,
371
+ rank=0,
372
+ t5_fsdp=False,
373
+ dit_fsdp=False,
374
+ use_usp=False,
375
+ model_filename=model_filename,
376
+ text_encoder_filename= text_encoder_filename
377
+ )
378
+
379
+ pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model }
380
+
381
+ return wan_model, pipe
382
+
383
+ def load_i2v_model(model_filename, value):
384
+
385
+
386
+ if value == '720P':
387
+ print("load 14B-720P i2v model...")
388
+ cfg = WAN_CONFIGS['i2v-14B']
389
+ wan_model = wan.WanI2V(
390
+ config=cfg,
391
+ checkpoint_dir="ckpts",
392
+ device_id=0,
393
+ rank=0,
394
+ t5_fsdp=False,
395
+ dit_fsdp=False,
396
+ use_usp=False,
397
+ i2v720p= True,
398
+ model_filename=model_filename,
399
+ text_encoder_filename=text_encoder_filename
400
+ )
401
+ pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "text_encoder_2": wan_model.clip.model, "vae": wan_model.vae.model } #
402
+
403
+ if value == '480P':
404
+ print("load 14B-480P i2v model...")
405
+ cfg = WAN_CONFIGS['i2v-14B']
406
+ wan_model = wan.WanI2V(
407
+ config=cfg,
408
+ checkpoint_dir="ckpts",
409
+ device_id=0,
410
+ rank=0,
411
+ t5_fsdp=False,
412
+ dit_fsdp=False,
413
+ use_usp=False,
414
+ i2v720p= False,
415
+ model_filename=model_filename,
416
+ text_encoder_filename=text_encoder_filename
417
+
418
+ )
419
+ pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "text_encoder_2": wan_model.clip.model, "vae": wan_model.vae.model } #
420
+
421
+ return wan_model, pipe
422
+
423
+ def load_models(i2v, lora_dir, lora_preselected_preset ):
424
+ download_models(transformer_filename_i2v if i2v else transformer_filename_t2v, text_encoder_filename)
425
+
426
+ if i2v:
427
+ res720P= "720p" in transformer_filename_i2v
428
+ wan_model, pipe = load_i2v_model(transformer_filename_i2v,"720P" if res720P else "480P")
429
+ else:
430
+ wan_model, pipe = load_t2v_model(transformer_filename_t2v,"")
431
+
432
+ kwargs = { "extraModelsToQuantize": None}
433
+ if profile == 2 or profile == 4:
434
+ kwargs["budgets"] = { "transformer" : 100, "*" : 3000 }
435
+
436
+ loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None)
437
+ offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, **kwargs)
438
+
439
+
440
+ return wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets
441
+
442
+ wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
443
+ gen_in_progress = False
444
+
445
+ def get_auto_attention():
446
+ for attn in ["sage2","sage","sdpa"]:
447
+ if attn in attention_modes_supported:
448
+ return attn
449
+ return "sdpa"
450
+
451
+ def get_default_flow(model_filename):
452
+ return 3.0 if "480p" in model_filename else 5.0
453
+
454
+ def generate_header(model_filename, compile, attention_mode):
455
+ header = "<H2 ALIGN=CENTER><SPAN> ----------------- "
456
+
457
+ if "image" in model_filename:
458
+ model_name = "Wan2.1 image2video"
459
+ model_name += "720p" if "720p" in model_filename else "480"
460
+ else:
461
+ model_name = "Wan2.1 text2video"
462
+ model_name += "14B" if "14B" in model_filename else "1.3B"
463
+
464
+ header += model_name
465
+ header += " (attention mode: " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
466
+ if attention_mode not in attention_modes_supported:
467
+ header += " -NOT INSTALLED-"
468
+
469
+ if compile:
470
+ header += ", pytorch compilation ON"
471
+ header += ") -----------------</SPAN></H2>"
472
+
473
+ return header
474
+
475
+ def apply_changes( state,
476
+ transformer_t2v_choice,
477
+ transformer_i2v_choice,
478
+ text_encoder_choice,
479
+ attention_choice,
480
+ compile_choice,
481
+ profile_choice,
482
+ vae_config_choice,
483
+ default_ui_choice ="t2v",
484
+ ):
485
+
486
+ if gen_in_progress:
487
+ yield "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
488
+ return
489
+ global offloadobj, wan_model, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets
490
+ server_config = {"attention_mode" : attention_choice,
491
+ "transformer_filename": transformer_choices_t2v[transformer_t2v_choice],
492
+ "transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice], ##########
493
+ "text_encoder_filename" : text_encoder_choices[text_encoder_choice],
494
+ "compile" : compile_choice,
495
+ "profile" : profile_choice,
496
+ "vae_config" : vae_config_choice,
497
+ "default_ui" : default_ui_choice,
498
+ }
499
+
500
+ if Path(server_config_filename).is_file():
501
+ with open(server_config_filename, "r", encoding="utf-8") as reader:
502
+ text = reader.read()
503
+ old_server_config = json.loads(text)
504
+ if lock_ui_transformer:
505
+ server_config["transformer_filename"] = old_server_config["transformer_filename"]
506
+ server_config["transformer_filename_i2v"] = old_server_config["transformer_filename_i2v"]
507
+ if lock_ui_attention:
508
+ server_config["attention_mode"] = old_server_config["attention_mode"]
509
+ if lock_ui_compile:
510
+ server_config["compile"] = old_server_config["compile"]
511
+
512
+ with open(server_config_filename, "w", encoding="utf-8") as writer:
513
+ writer.write(json.dumps(server_config))
514
+
515
+ changes = []
516
+ for k, v in server_config.items():
517
+ v_old = old_server_config.get(k, None)
518
+ if v != v_old:
519
+ changes.append(k)
520
+
521
+ state["config_changes"] = changes
522
+ state["config_new"] = server_config
523
+ state["config_old"] = old_server_config
524
+
525
+ global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config
526
+ attention_mode = server_config["attention_mode"]
527
+ profile = server_config["profile"]
528
+ compile = server_config["compile"]
529
+ transformer_filename_t2v = server_config["transformer_filename"]
530
+ transformer_filename_i2v = server_config["transformer_filename_i2v"]
531
+ text_encoder_filename = server_config["text_encoder_filename"]
532
+ vae_config = server_config["vae_config"]
533
+
534
+ if all(change in ["attention_mode", "vae_config", "default_ui"] for change in changes ):
535
+ if "attention_mode" in changes:
536
+ pass
537
+
538
+ else:
539
+ wan_model = None
540
+ offloadobj.release()
541
+ offloadobj = None
542
+ yield "<DIV ALIGN=CENTER>Please wait while the new configuration is being applied</DIV>"
543
+
544
+ wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
545
+
546
+
547
+ yield "<DIV ALIGN=CENTER>The new configuration has been succesfully applied</DIV>"
548
+
549
+ # return "<DIV ALIGN=CENTER>New Config file created. Please restart the Gradio Server</DIV>"
550
+
551
+ def update_defaults(state, num_inference_steps, flow_shift):
552
+ if "config_changes" not in state:
553
+ return num_inference_steps, flow_shift, gr.update()
554
+ changes = state["config_changes"]
555
+ server_config = state["config_new"]
556
+ old_server_config = state["config_old"]
557
+
558
+ if use_image2video:
559
+ old_is_720P = "720p" in old_server_config["transformer_filename_i2v"]
560
+ new_is_720P = "720p" in server_config["transformer_filename_i2v"]
561
+
562
+ trans_file = server_config["transformer_filename_i2v"]
563
+ if old_is_720P != new_is_720P:
564
+ flow_shift = get_default_flow(trans_file)
565
+ else:
566
+ old_is_14B = "14B" in old_server_config["transformer_filename"]
567
+ new_is_14B = "14B" in server_config["transformer_filename"]
568
+ trans_file = server_config["transformer_filename"]
569
+ # if old_is_14B != new_is_14B:
570
+ # flow_shift = get_default_flow(trans_file)
571
+
572
+ header = generate_header(trans_file, server_config["compile"], server_config["attention_mode"] )
573
+ return num_inference_steps, flow_shift, header
574
+
575
+
576
+ from moviepy.editor import ImageSequenceClip
577
+ import numpy as np
578
+
579
+ def save_video(final_frames, output_path, fps=24):
580
+ assert final_frames.ndim == 4 and final_frames.shape[3] == 3, f"invalid shape: {final_frames} (need t h w c)"
581
+ if final_frames.dtype != np.uint8:
582
+ final_frames = (final_frames * 255).astype(np.uint8)
583
+ ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)
584
+
585
+ def build_callback(state, pipe, progress, status, num_inference_steps):
586
+ def callback(step_idx, latents):
587
+ step_idx += 1
588
+ if state.get("abort", False):
589
+ # pipe._interrupt = True
590
+ status_msg = status + " - Aborting"
591
+ elif step_idx == num_inference_steps:
592
+ status_msg = status + " - VAE Decoding"
593
+ else:
594
+ status_msg = status + " - Denoising"
595
+
596
+ progress( (step_idx , num_inference_steps) , status_msg , num_inference_steps)
597
+
598
+ return callback
599
+
600
+ def abort_generation(state):
601
+ if "in_progress" in state:
602
+ state["abort"] = True
603
+ wan_model._interrupt= True
604
+ return gr.Button(interactive= False)
605
+ else:
606
+ return gr.Button(interactive= True)
607
+
608
+ def refresh_gallery(state):
609
+ file_list = state.get("file_list", None)
610
+ return file_list
611
+
612
+ def finalize_gallery(state):
613
+ choice = 0
614
+ if "in_progress" in state:
615
+ del state["in_progress"]
616
+ choice = state.get("selected",0)
617
+
618
+ time.sleep(0.2)
619
+ global gen_in_progress
620
+ gen_in_progress = False
621
+ return gr.Gallery(selected_index=choice), gr.Button(interactive= True)
622
+
623
+ def select_video(state , event_data: gr.EventData):
624
+ data= event_data._data
625
+ if data!=None:
626
+ state["selected"] = data.get("index",0)
627
+ return
628
+
629
+ def expand_slist(slist, num_inference_steps ):
630
+ new_slist= []
631
+ inc = len(slist) / num_inference_steps
632
+ pos = 0
633
+ for i in range(num_inference_steps):
634
+ new_slist.append(slist[ int(pos)])
635
+ pos += inc
636
+ return new_slist
637
+
638
+
639
+ def generate_video(
640
+ prompt,
641
+ negative_prompt,
642
+ resolution,
643
+ video_length,
644
+ seed,
645
+ num_inference_steps,
646
+ guidance_scale,
647
+ flow_shift,
648
+ embedded_guidance_scale,
649
+ repeat_generation,
650
+ tea_cache,
651
+ loras_choices,
652
+ loras_mult_choices,
653
+ image_to_continue,
654
+ video_to_continue,
655
+ max_frames,
656
+ RIFLEx_setting,
657
+ state,
658
+ progress=gr.Progress() #track_tqdm= True
659
+
660
+ ):
661
+
662
+ from PIL import Image
663
+ import numpy as np
664
+ import tempfile
665
+
666
+
667
+ if wan_model == None:
668
+ raise gr.Error("Unable to generate a Video while a new configuration is being applied.")
669
+ if attention_mode == "auto":
670
+ attn = get_auto_attention()
671
+ elif attention_mode in attention_modes_supported:
672
+ attn = attention_mode
673
+ else:
674
+ raise gr.Error(f"You have selected attention mode '{attention_mode}'. However it is not installed on your system. You should either install it or switch to the default 'sdpa' attention.")
675
+
676
+ width, height = resolution.split("x")
677
+ width, height = int(width), int(height)
678
+
679
+
680
+ if use_image2video:
681
+ if "480p" in transformer_filename_i2v and width * height > 848*480:
682
+ raise gr.Error("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
683
+
684
+ resolution = str(width) + "*" + str(height)
685
+ if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
686
+ raise gr.Error(f"Resolution {resolution} not supported by image 2 video")
687
+
688
+
689
+ else:
690
+ if "1.3B" in transformer_filename_t2v and width * height > 848*480:
691
+ raise gr.Error("You must use the 14B text to video model to generate videos with a resolution equivalent to 720P")
692
+
693
+ offload.shared_state["_vae"] = vae_config
694
+ offload.shared_state["_vae_threshold"] = 0.9* torch.cuda.get_device_properties(0).total_memory
695
+
696
+ offload.shared_state["_attention"] = attn
697
+
698
+ global gen_in_progress
699
+ gen_in_progress = True
700
+ temp_filename = None
701
+ if use_image2video:
702
+ if image_to_continue is not None:
703
+ pass
704
+
705
+ elif video_to_continue != None and len(video_to_continue) >0 :
706
+ input_image_or_video_path = video_to_continue
707
+ # pipeline.num_input_frames = max_frames
708
+ # pipeline.max_frames = max_frames
709
+ else:
710
+ return
711
+ else:
712
+ input_image_or_video_path = None
713
+
714
+
715
+ if len(loras) > 0:
716
+ def is_float(element: any) -> bool:
717
+ if element is None:
718
+ return False
719
+ try:
720
+ float(element)
721
+ return True
722
+ except ValueError:
723
+ return False
724
+ list_mult_choices_nums = []
725
+ if len(loras_mult_choices) > 0:
726
+ list_mult_choices_str = loras_mult_choices.split(" ")
727
+ for i, mult in enumerate(list_mult_choices_str):
728
+ mult = mult.strip()
729
+ if "," in mult:
730
+ multlist = mult.split(",")
731
+ slist = []
732
+ for smult in multlist:
733
+ if not is_float(smult):
734
+ raise gr.Error(f"Lora sub value no {i+1} ({smult}) in Multiplier definition '{multlist}' is invalid")
735
+ slist.append(float(smult))
736
+ slist = expand_slist(slist, num_inference_steps )
737
+ list_mult_choices_nums.append(slist)
738
+ else:
739
+ if not is_float(mult):
740
+ raise gr.Error(f"Lora Multiplier no {i+1} ({mult}) is invalid")
741
+ list_mult_choices_nums.append(float(mult))
742
+ if len(list_mult_choices_nums ) < len(loras_choices):
743
+ list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) )
744
+
745
+ offload.activate_loras(wan_model.model, loras_choices, list_mult_choices_nums)
746
+
747
+ seed = None if seed == -1 else seed
748
+ # negative_prompt = "" # not applicable in the inference
749
+
750
+ if "abort" in state:
751
+ del state["abort"]
752
+ state["in_progress"] = True
753
+ state["selected"] = 0
754
+
755
+ enable_riflex = RIFLEx_setting == 0 and video_length > (5* 24) or RIFLEx_setting == 1
756
+ # VAE Tiling
757
+ device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
758
+
759
+
760
+ # TeaCache
761
+ trans = wan_model.model
762
+ trans.enable_teacache = tea_cache > 0
763
+
764
+ import random
765
+ if seed == None or seed <0:
766
+ seed = random.randint(0, 999999999)
767
+
768
+ file_list = []
769
+ state["file_list"] = file_list
770
+ from einops import rearrange
771
+ save_path = os.path.join(os.getcwd(), "gradio_outputs")
772
+ os.makedirs(save_path, exist_ok=True)
773
+ prompts = prompt.replace("\r", "").split("\n")
774
+ video_no = 0
775
+ total_video = repeat_generation * len(prompts)
776
+ abort = False
777
+ start_time = time.time()
778
+ for prompt in prompts:
779
+ for _ in range(repeat_generation):
780
+ if abort:
781
+ break
782
+
783
+ if trans.enable_teacache:
784
+ trans.num_steps = num_inference_steps
785
+ trans.cnt = 0
786
+ trans.rel_l1_thresh = tea_cache #0.15 # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
787
+ trans.accumulated_rel_l1_distance = 0
788
+ trans.previous_modulated_input = None
789
+ trans.previous_residual = None
790
+
791
+ video_no += 1
792
+ status = f"Video {video_no}/{total_video}"
793
+ progress(0, desc=status + " - Encoding Prompt" )
794
+
795
+ callback = build_callback(state, trans, progress, status, num_inference_steps)
796
+
797
+
798
+ gc.collect()
799
+ torch.cuda.empty_cache()
800
+ try:
801
+ if use_image2video:
802
+ samples = wan_model.generate(
803
+ prompt,
804
+ image_to_continue,
805
+ frame_num=(video_length // 4)* 4 + 1,
806
+ max_area=MAX_AREA_CONFIGS[resolution],
807
+ shift=flow_shift,
808
+ sampling_steps=num_inference_steps,
809
+ guide_scale=guidance_scale,
810
+ n_prompt=negative_prompt,
811
+ seed=seed,
812
+ offload_model=False,
813
+ callback=callback
814
+ )
815
+
816
+ else:
817
+ samples = wan_model.generate(
818
+ prompt,
819
+ frame_num=(video_length // 4)* 4 + 1,
820
+ size=(width, height),
821
+ shift=flow_shift,
822
+ sampling_steps=num_inference_steps,
823
+ guide_scale=guidance_scale,
824
+ n_prompt=negative_prompt,
825
+ seed=seed,
826
+ offload_model=False,
827
+ callback=callback
828
+ )
829
+ except:
830
+ gen_in_progress = False
831
+ if temp_filename!= None and os.path.isfile(temp_filename):
832
+ os.remove(temp_filename)
833
+ offload.last_offload_obj.unload_all()
834
+ # if compile:
835
+ # cache_size = torch._dynamo.config.cache_size_limit
836
+ # torch.compiler.reset()
837
+ # torch._dynamo.config.cache_size_limit = cache_size
838
+
839
+ gc.collect()
840
+ torch.cuda.empty_cache()
841
+ raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
842
+
843
+
844
+ if samples != None:
845
+ samples = samples.to("cpu")
846
+ offload.last_offload_obj.unload_all()
847
+ gc.collect()
848
+ torch.cuda.empty_cache()
849
+
850
+ if samples == None:
851
+ end_time = time.time()
852
+ abort = True
853
+ yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
854
+ else:
855
+ sample = samples.cpu()
856
+ # video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
857
+
858
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
859
+ file_name = f"{time_flag}_seed{seed}_{prompt[:100].replace('/','').strip()}.mp4".replace(':',' ').replace('\\',' ')
860
+ video_path = os.path.join(os.getcwd(), "gradio_outputs", file_name)
861
+ cache_video(
862
+ tensor=sample[None],
863
+ save_file=video_path,
864
+ fps=16,
865
+ nrow=1,
866
+ normalize=True,
867
+ value_range=(-1, 1))
868
+
869
+ print(f"New video saved to Path: "+video_path)
870
+ file_list.append(video_path)
871
+ if video_no < total_video:
872
+ yield status
873
+ else:
874
+ end_time = time.time()
875
+ yield f"Total Generation Time: {end_time-start_time:.1f}s"
876
+ seed += 1
877
+
878
+ if temp_filename!= None and os.path.isfile(temp_filename):
879
+ os.remove(temp_filename)
880
+ gen_in_progress = False
881
+
882
+ new_preset_msg = "Enter a Name for a Lora Preset or Choose One Above"
883
+
884
+ def save_lset(lset_name, loras_choices, loras_mult_choices):
885
+ global loras_presets
886
+
887
+ if len(lset_name) == 0 or lset_name== new_preset_msg:
888
+ gr.Info("Please enter a name for the preset")
889
+ lset_choices =[("Please enter a name for a Lora Preset","")]
890
+ else:
891
+ lset_name = sanitize_file_name(lset_name)
892
+
893
+ loras_choices_files = [ Path(loras[int(choice_no)]).parts[-1] for choice_no in loras_choices ]
894
+ lset = {"loras" : loras_choices_files, "loras_mult" : loras_mult_choices}
895
+ lset_name_filename = lset_name + ".lset"
896
+ full_lset_name_filename = os.path.join(lora_dir, lset_name_filename)
897
+
898
+ with open(full_lset_name_filename, "w", encoding="utf-8") as writer:
899
+ writer.write(json.dumps(lset))
900
+
901
+ if lset_name in loras_presets:
902
+ gr.Info(f"Lora Preset '{lset_name}' has been updated")
903
+ else:
904
+ gr.Info(f"Lora Preset '{lset_name}' has been created")
905
+ loras_presets.append(Path(Path(lset_name_filename).parts[-1]).stem )
906
+ lset_choices = [ ( preset, preset) for preset in loras_presets ]
907
+ lset_choices.append( (new_preset_msg, ""))
908
+
909
+ return gr.Dropdown(choices=lset_choices, value= lset_name)
910
+
911
+ def delete_lset(lset_name):
912
+ global loras_presets
913
+ lset_name_filename = os.path.join(lora_dir, sanitize_file_name(lset_name) + ".lset" )
914
+ if len(lset_name) > 0 and lset_name != new_preset_msg:
915
+ if not os.path.isfile(lset_name_filename):
916
+ raise gr.Error(f"Preset '{lset_name}' not found ")
917
+ os.remove(lset_name_filename)
918
+ pos = loras_presets.index(lset_name)
919
+ gr.Info(f"Lora Preset '{lset_name}' has been deleted")
920
+ loras_presets.remove(lset_name)
921
+ else:
922
+ pos = len(loras_presets)
923
+ gr.Info(f"Choose a Preset to delete")
924
+
925
+ lset_choices = [ (preset, preset) for preset in loras_presets]
926
+ lset_choices.append((new_preset_msg, ""))
927
+ return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1])
928
+
929
+ def apply_lset(lset_name, loras_choices, loras_mult_choices):
930
+
931
+ if len(lset_name) == 0 or lset_name== new_preset_msg:
932
+ gr.Info("Please choose a preset in the list or create one")
933
+ else:
934
+ loras_choices, loras_mult_choices= extract_preset(lset_name, loras)
935
+ gr.Info(f"Lora Preset '{lset_name}' has been applied")
936
+
937
+ return loras_choices, loras_mult_choices
938
+
939
+ def create_demo():
940
+
941
+ default_inference_steps = 30
942
+ default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v)
943
+
944
+ with gr.Blocks() as demo:
945
+ state = gr.State({})
946
+
947
+ if use_image2video:
948
+ gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1 - AI Image To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
949
+ else:
950
+ gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1 - AI Text To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
951
+
952
+ gr.Markdown("<FONT SIZE=3>With this first release of Wan 2.1GP by <B>DeepBeepMeep</B> the VRAM requirements have been divided by more than 2 with no quality loss</FONT>")
953
+
954
+ if use_image2video and False:
955
+ pass
956
+ else:
957
+ gr.Markdown("The VRAM requirements will depend greatly of the resolution and the duration of the video, for instance : 24 GB of VRAM (RTX 3090 / RTX 4090), the limits are as follows:")
958
+ gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM")
959
+ gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
960
+ gr.Markdown("- 1280 x 720 with a 14B model: 192 frames (8s): 11 GB of VRAM")
961
+ gr.Markdown("Note that the VAE stages (encoding / decoding at image2video ) or just the decoding at text2video will create a temporary VRAM peak (up to 12GB for 420P and 22 GB for 720P)")
962
+
963
+ gr.Markdown("Please note that if your turn on compilation, the first generation step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
964
+
965
+
966
+ # css = """<STYLE>
967
+ # h2 { width: 100%; text-align: center; border-bottom: 1px solid #000; line-height: 0.1em; margin: 10px 0 20px; }
968
+ # h2 span {background:#fff; padding:0 10px; }</STYLE>"""
969
+ # gr.HTML(css)
970
+
971
+ header = gr.Markdown(generate_header(transformer_filename_i2v if use_image2video else transformer_filename_t2v, compile, attention_mode) )
972
+
973
+ with gr.Accordion("Video Engine Configuration - click here to change it", open = False):
974
+ gr.Markdown("For the changes to be effective you will need to restart the gradio_server. Some choices below may be locked if the app has been launched by specifying a config preset.")
975
+
976
+ with gr.Column():
977
+ index = transformer_choices_t2v.index(transformer_filename_t2v)
978
+ index = 0 if index ==0 else index
979
+ transformer_t2v_choice = gr.Dropdown(
980
+ choices=[
981
+ ("WAN 2.1 1.3B Text to Video 16 bits - the small model for fast generations with low VRAM requirements", 0),
982
+ ("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1),
983
+ ("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2),
984
+ ],
985
+ value= index,
986
+ label="Transformer model for Text to Video",
987
+ interactive= not lock_ui_transformer,
988
+ visible=not use_image2video
989
+ )
990
+
991
+ index = transformer_choices_i2v.index(transformer_filename_i2v)
992
+ index = 0 if index ==0 else index
993
+ transformer_i2v_choice = gr.Dropdown(
994
+ choices=[
995
+ ("WAN 2.1 - 480p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 0),
996
+ ("WAN 2.1 - 480p 14B Image to Video quantized to 8 bits (recommended) - the default engine but quantized", 1),
997
+ ("WAN 2.1 - 720p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 2),
998
+ ("WAN 2.1 - 720p 14B Image to Video quantized to 8 bits (recommended) - the default engine but quantized", 3),
999
+ ],
1000
+ value= index,
1001
+ label="Transformer model for Image to Video",
1002
+ interactive= not lock_ui_transformer,
1003
+ visible = use_image2video, ###############
1004
+ )
1005
+
1006
+ index = text_encoder_choices.index(text_encoder_filename)
1007
+ index = 0 if index ==0 else index
1008
+
1009
+ text_encoder_choice = gr.Dropdown(
1010
+ choices=[
1011
+ ("UMT5 XXL 16 bits - unquantized text encoder, better quality uses more RAM", 0),
1012
+ ("UMT5 XXL quantized to 8 bits - quantized text encoder, slightly worse quality but uses less RAM", 1),
1013
+ ],
1014
+ value= index,
1015
+ label="Text Encoder model"
1016
+ )
1017
+ def check(mode):
1018
+ if not mode in attention_modes_supported:
1019
+ return " (NOT INSTALLED)"
1020
+ else:
1021
+ return ""
1022
+ attention_choice = gr.Dropdown(
1023
+ choices=[
1024
+ ("Auto : pick sage2 > sage > sdpa depending on what is installed", "auto"),
1025
+ ("Scale Dot Product Attention: default, always available", "sdpa"),
1026
+ ("Flash" + check("flash")+ ": good quality - requires additional install (usually complex to set up on Windows without WSL)", "flash"),
1027
+ # ("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"),
1028
+ ("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"),
1029
+ ("Sage2" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"),
1030
+ ],
1031
+ value= attention_mode,
1032
+ label="Attention Type",
1033
+ interactive= not lock_ui_attention
1034
+ )
1035
+ gr.Markdown("Beware: when restarting the server or changing a resolution or video duration, the first step of generation for a duration / resolution may last a few minutes due to recompilation")
1036
+ compile_choice = gr.Dropdown(
1037
+ choices=[
1038
+ ("ON: works only on Linux / WSL", "transformer"),
1039
+ ("OFF: no other choice if you have Windows without using WSL", "" ),
1040
+ ],
1041
+ value= compile,
1042
+ label="Compile Transformer (up to 50% faster and 30% more frames but requires Linux / WSL and Flash or Sage attention)",
1043
+ interactive= not lock_ui_compile
1044
+ )
1045
+
1046
+
1047
+ vae_config_choice = gr.Dropdown(
1048
+ choices=[
1049
+ ("Auto", 0),
1050
+ ("Disabled (faster but may require up to 24 GB of VRAM)", 1),
1051
+ ("Enabled (2x slower and up to 50% VRAM reduction)", 2),
1052
+ ],
1053
+ value= vae_config,
1054
+ label="VAE optimisations - reduce the VRAM requirements for VAE decoding and VAE encoding"
1055
+ )
1056
+
1057
+ profile_choice = gr.Dropdown(
1058
+ choices=[
1059
+ ("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for short videos a RTX 3090 / RTX 4090", 1),
1060
+ ("HighRAM_LowVRAM, profile 2 (Recommended): at least 48 GB of RAM and 12 GB of VRAM, the most versatile profile with high RAM, better suited for RTX 3070/3080/4070/4080 or for RTX 3090 / RTX 4090 with large pictures batches or long videos", 2),
1061
+ ("LowRAM_HighVRAM, profile 3: at least 32 GB of RAM and 24 GB of VRAM, adapted for RTX 3090 / RTX 4090 with limited RAM for good speed short video",3),
1062
+ ("LowRAM_LowVRAM, profile 4 (Default): at least 32 GB of RAM and 12 GB of VRAM, if you have little VRAM or want to generate longer videos",4),
1063
+ ("VerylowRAM_LowVRAM, profile 5: (Fail safe): at least 16 GB of RAM and 10 GB of VRAM, if you don't have much it won't be fast but maybe it will work",5)
1064
+ ],
1065
+ value= profile,
1066
+ label="Profile (for power users only, not needed to change it)"
1067
+ )
1068
+
1069
+ default_ui_choice = gr.Dropdown(
1070
+ choices=[
1071
+ ("Text to Video", "t2v"),
1072
+ ("Image to Video", "i2v"),
1073
+ ],
1074
+ value= default_ui,
1075
+ label="Default mode when launching the App if not '--t2v' ot '--i2v' switch is specified when launching the server ",
1076
+ # visible= True ############
1077
+ )
1078
+
1079
+ msg = gr.Markdown()
1080
+ apply_btn = gr.Button("Apply Changes")
1081
+
1082
+
1083
+ with gr.Row():
1084
+ with gr.Column():
1085
+ video_to_continue = gr.Video(label= "Video to continue", visible= use_image2video and False) #######
1086
+ image_to_continue = gr.Image(label= "Image as a starting point for a new video", visible=use_image2video)
1087
+
1088
+ if use_image2video:
1089
+ prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos)", value="Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.", lines=3)
1090
+ else:
1091
+ prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos)", value="A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect.", lines=3)
1092
+
1093
+
1094
+ with gr.Row():
1095
+ resolution = gr.Dropdown(
1096
+ choices=[
1097
+ # 720p
1098
+ ("1280x720 (16:9, 720p)", "1280x720"),
1099
+ ("720x1280 (9:16, 720p)", "720x1280"),
1100
+ ("1024x1024 (4:3, 720p, T2V only)", "1024x024"),
1101
+ # ("832x1104 (3:4, 720p)", "832x1104"),
1102
+ # ("960x960 (1:1, 720p)", "960x960"),
1103
+ # 480p
1104
+ # ("960x544 (16:9, 480p)", "960x544"),
1105
+ ("832x480 (16:9, 480p)", "832x480"),
1106
+ ("480x832 (9:16, 480p)", "480x832"),
1107
+ # ("832x624 (4:3, 540p)", "832x624"),
1108
+ # ("624x832 (3:4, 540p)", "624x832"),
1109
+ # ("720x720 (1:1, 540p)", "720x720"),
1110
+ ],
1111
+ value="832x480",
1112
+ label="Resolution"
1113
+ )
1114
+
1115
+ with gr.Row():
1116
+ with gr.Column():
1117
+ video_length = gr.Slider(5, 169, value=81, step=4, label="Number of frames (16 = 1s)")
1118
+ with gr.Column():
1119
+ num_inference_steps = gr.Slider(1, 100, value= default_inference_steps, step=1, label="Number of Inference Steps")
1120
+
1121
+ with gr.Row():
1122
+ max_frames = gr.Slider(1, 100, value=9, step=1, label="Number of input frames to use for Video2World prediction", visible=use_image2video and False) #########
1123
+
1124
+
1125
+ with gr.Row(visible= len(loras)>0):
1126
+ lset_choices = [ (preset, preset) for preset in loras_presets ] + [(new_preset_msg, "")]
1127
+ with gr.Column(scale=5):
1128
+ lset_name = gr.Dropdown(show_label=False, allow_custom_value= True, scale=5, filterable=False, choices= lset_choices, value=default_lora_preset)
1129
+ with gr.Column(scale=1):
1130
+ # with gr.Column():
1131
+ with gr.Row(height=17):
1132
+ apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1)
1133
+ with gr.Row(height=17):
1134
+ save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
1135
+ delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
1136
+
1137
+
1138
+ loras_choices = gr.Dropdown(
1139
+ choices=[
1140
+ (lora_name, str(i) ) for i, lora_name in enumerate(loras_names)
1141
+ ],
1142
+ value= default_loras_choices,
1143
+ multiselect= True,
1144
+ visible= len(loras)>0,
1145
+ label="Activated Loras"
1146
+ )
1147
+ loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns", value=default_loras_multis_str, visible= len(loras)>0 )
1148
+
1149
+ show_advanced = gr.Checkbox(label="Show Advanced Options", value=False)
1150
+ with gr.Row(visible=False) as advanced_row:
1151
+ with gr.Column():
1152
+ seed = gr.Slider(-1, 999999999, value=-1, step=1, label="Seed (-1 for random)")
1153
+ repeat_generation = gr.Slider(1, 25.0, value=1.0, step=1, label="Number of Generated Video per prompt")
1154
+ with gr.Row():
1155
+ negative_prompt = gr.Textbox(label="Negative Prompt", value="")
1156
+ with gr.Row():
1157
+ guidance_scale = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Guidance Scale", visible=True)
1158
+ embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
1159
+ flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale")
1160
+ tea_cache_setting = gr.Dropdown(
1161
+ choices=[
1162
+ ("Disabled", 0),
1163
+ ("Fast (x1.6 speed up)", 0.1),
1164
+ ("Faster (x2.1 speed up)", 0.15),
1165
+ ],
1166
+ value=default_tea_cache,
1167
+ visible=False,
1168
+ label="Tea Cache acceleration (the faster the acceleration the higher the degradation of the quality of the video. Consumes VRAM)"
1169
+ )
1170
+
1171
+ RIFLEx_setting = gr.Dropdown(
1172
+ choices=[
1173
+ ("Auto (ON if Video longer than 5s)", 0),
1174
+ ("Always ON", 1),
1175
+ ("Always OFF", 2),
1176
+ ],
1177
+ value=0,
1178
+ label="RIFLEx positional embedding to generate long video"
1179
+ )
1180
+
1181
+ show_advanced.change(fn=lambda x: gr.Row(visible=x), inputs=[show_advanced], outputs=[advanced_row])
1182
+
1183
+ with gr.Column():
1184
+ gen_status = gr.Text(label="Status", interactive= False)
1185
+ output = gr.Gallery(
1186
+ label="Generated videos", show_label=False, elem_id="gallery"
1187
+ , columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= False)
1188
+ generate_btn = gr.Button("Generate")
1189
+ abort_btn = gr.Button("Abort")
1190
+
1191
+ save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices], outputs=[lset_name])
1192
+ delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name])
1193
+ apply_lset_btn.click(apply_lset, inputs=[lset_name,loras_choices, loras_mult_choices], outputs=[loras_choices, loras_mult_choices])
1194
+
1195
+ gen_status.change(refresh_gallery, inputs = [state], outputs = output )
1196
+
1197
+ abort_btn.click(abort_generation,state,abort_btn )
1198
+ output.select(select_video, state, None )
1199
+
1200
+ generate_btn.click(
1201
+ fn=generate_video,
1202
+ inputs=[
1203
+ prompt,
1204
+ negative_prompt,
1205
+ resolution,
1206
+ video_length,
1207
+ seed,
1208
+ num_inference_steps,
1209
+ guidance_scale,
1210
+ flow_shift,
1211
+ embedded_guidance_scale,
1212
+ repeat_generation,
1213
+ tea_cache_setting,
1214
+ loras_choices,
1215
+ loras_mult_choices,
1216
+ image_to_continue,
1217
+ video_to_continue,
1218
+ max_frames,
1219
+ RIFLEx_setting,
1220
+ state
1221
+ ],
1222
+ outputs= [gen_status] #,state
1223
+
1224
+ ).then(
1225
+ finalize_gallery,
1226
+ [state],
1227
+ [output , abort_btn]
1228
+ )
1229
+
1230
+ apply_btn.click(
1231
+ fn=apply_changes,
1232
+ inputs=[
1233
+ state,
1234
+ transformer_t2v_choice,
1235
+ transformer_i2v_choice,
1236
+ text_encoder_choice,
1237
+ attention_choice,
1238
+ compile_choice,
1239
+ profile_choice,
1240
+ vae_config_choice,
1241
+ default_ui_choice,
1242
+ ],
1243
+ outputs= msg
1244
+ ).then(
1245
+ update_defaults,
1246
+ [state, num_inference_steps, flow_shift],
1247
+ [num_inference_steps, flow_shift, header]
1248
+ )
1249
+
1250
+ return demo
1251
+
1252
+ if __name__ == "__main__":
1253
+ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
1254
+ server_port = int(args.server_port)
1255
+
1256
+ if server_port == 0:
1257
+ server_port = int(os.getenv("SERVER_PORT", "7860"))
1258
+
1259
+ server_name = args.server_name
1260
+ if len(server_name) == 0:
1261
+ server_name = os.getenv("SERVER_NAME", "localhost")
1262
+
1263
+
1264
+ demo = create_demo()
1265
+ if args.open_browser:
1266
+ import webbrowser
1267
+ if server_name.startswith("http"):
1268
+ url = server_name
1269
+ else:
1270
+ url = "http://" + server_name
1271
+ webbrowser.open(url + ":" + str(server_port), new = 0, autoraise = True)
1272
+
1273
+ demo.launch(server_name=server_name, server_port=server_port)
1274
+
1275
+
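
As a quick aid to reading the Lora multiplier handling in the server above, the expand_slist helper is reproduced below in isolation with a small usage example (the multiplier values are arbitrary).

```python
# expand_slist() stretches a short per-phase multiplier list so that every
# denoising step gets a value, e.g. "0.9,0.8" over 6 steps.
def expand_slist(slist, num_inference_steps):
    new_slist = []
    inc = len(slist) / num_inference_steps
    pos = 0
    for _ in range(num_inference_steps):
        new_slist.append(slist[int(pos)])
        pos += inc
    return new_slist

print(expand_slist([0.9, 0.8], 6))   # [0.9, 0.9, 0.9, 0.8, 0.8, 0.8]
```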
loras/README.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Put your Loras here
loras_i2v/README.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Put your Loras here
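
The .lset preset files that gradio_server.py loads from these Lora folders are plain JSON; a hypothetical example follows (the file names are made up for illustration).

```python
# Sketch of a Lora preset file as written by save_lset() in gradio_server.py:
# "loras" lists file names inside the Lora folder, "loras_mult" is the
# space-separated multiplier string (a comma list gives per-step values).
import json

preset = {
    "loras": ["my_style.safetensors", "my_character.safetensors"],  # hypothetical names
    "loras_mult": "1.0 0.9,0.8",
}

with open("loras/my_preset.lset", "w", encoding="utf-8") as writer:
    writer.write(json.dumps(preset))
```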
requirements.txt CHANGED
@@ -11,6 +11,9 @@ easydict
11
  ftfy
12
  dashscope
13
  imageio-ffmpeg
14
- flash_attn
15
  gradio>=5.0.0
16
  numpy>=1.23.5,<2
 
 
 
 
11
  ftfy
12
  dashscope
13
  imageio-ffmpeg
14
+ # flash_attn
15
  gradio>=5.0.0
16
  numpy>=1.23.5,<2
17
+ einops
18
+ moviepy==1.0.3
19
+ mmgp==3.2.1
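
A note on the moviepy pin: gradio_server.py imports the 1.x editor namespace, which newer moviepy releases dropped, so the exact pin is deliberate.

```python
# The server uses the moviepy 1.x import path; with moviepy >= 2.0 this import
# no longer exists, hence the moviepy==1.0.3 pin above.
from moviepy.editor import ImageSequenceClip
```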
wan/image2video.py CHANGED
@@ -39,6 +39,9 @@ class WanI2V:
39
  use_usp=False,
40
  t5_cpu=False,
41
  init_on_cpu=True,
 
 
 
42
  ):
43
  r"""
44
  Initializes the image-to-video generation model components.
@@ -77,7 +80,7 @@ class WanI2V:
77
  text_len=config.text_len,
78
  dtype=config.t5_dtype,
79
  device=torch.device('cpu'),
80
- checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
81
  tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
82
  shard_fn=shard_fn if t5_fsdp else None,
83
  )
@@ -95,8 +98,10 @@ class WanI2V:
95
  config.clip_checkpoint),
96
  tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
97
 
98
- logging.info(f"Creating WanModel from {checkpoint_dir}")
99
- self.model = WanModel.from_pretrained(checkpoint_dir)
 
 
100
  self.model.eval().requires_grad_(False)
101
 
102
  if t5_fsdp or dit_fsdp or use_usp:
@@ -116,28 +121,30 @@ class WanI2V:
116
  else:
117
  self.sp_size = 1
118
 
119
- if dist.is_initialized():
120
- dist.barrier()
121
- if dit_fsdp:
122
- self.model = shard_fn(self.model)
123
- else:
124
- if not init_on_cpu:
125
- self.model.to(self.device)
126
 
127
  self.sample_neg_prompt = config.sample_neg_prompt
128
 
129
  def generate(self,
130
- input_prompt,
131
- img,
132
- max_area=720 * 1280,
133
- frame_num=81,
134
- shift=5.0,
135
- sample_solver='unipc',
136
- sampling_steps=40,
137
- guide_scale=5.0,
138
- n_prompt="",
139
- seed=-1,
140
- offload_model=True):
 
 
141
  r"""
142
  Generates video frames from input image and text prompt using diffusion process.
143
 
@@ -197,14 +204,14 @@ class WanI2V:
197
  seed_g.manual_seed(seed)
198
  noise = torch.randn(
199
  16,
200
- 21,
201
  lat_h,
202
  lat_w,
203
  dtype=torch.float32,
204
  generator=seed_g,
205
  device=self.device)
206
 
207
- msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
208
  msk[:, 1:] = 0
209
  msk = torch.concat([
210
  torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
@@ -218,7 +225,7 @@ class WanI2V:
218
 
219
  # preprocess
220
  if not self.t5_cpu:
221
- self.text_encoder.model.to(self.device)
222
  context = self.text_encoder([input_prompt], self.device)
223
  context_null = self.text_encoder([n_prompt], self.device)
224
  if offload_model:
@@ -229,20 +236,23 @@ class WanI2V:
229
  context = [t.to(self.device) for t in context]
230
  context_null = [t.to(self.device) for t in context_null]
231
 
232
- self.clip.model.to(self.device)
233
  clip_context = self.clip.visual([img[:, None, :, :]])
234
  if offload_model:
235
  self.clip.model.cpu()
236
 
237
- y = self.vae.encode([
238
- torch.concat([
239
- torch.nn.functional.interpolate(
240
- img[None].cpu(), size=(h, w), mode='bicubic').transpose(
241
- 0, 1),
242
- torch.zeros(3, 80, h, w)
243
- ],
244
- dim=1).to(self.device)
245
- ])[0]
 
 
 
246
  y = torch.concat([msk, y])
247
 
248
  @contextmanager
@@ -283,6 +293,7 @@ class WanI2V:
283
  'clip_fea': clip_context,
284
  'seq_len': max_seq_len,
285
  'y': [y],
 
286
  }
287
 
288
  arg_null = {
@@ -290,30 +301,39 @@ class WanI2V:
290
  'clip_fea': clip_context,
291
  'seq_len': max_seq_len,
292
  'y': [y],
 
293
  }
294
 
295
  if offload_model:
296
  torch.cuda.empty_cache()
297
 
298
- self.model.to(self.device)
299
- for _, t in enumerate(tqdm(timesteps)):
 
 
 
 
300
  latent_model_input = [latent.to(self.device)]
301
  timestep = [t]
302
 
303
  timestep = torch.stack(timestep).to(self.device)
304
 
305
  noise_pred_cond = self.model(
306
- latent_model_input, t=timestep, **arg_c)[0].to(
307
- torch.device('cpu') if offload_model else self.device)
 
308
  if offload_model:
309
  torch.cuda.empty_cache()
310
  noise_pred_uncond = self.model(
311
- latent_model_input, t=timestep, **arg_null)[0].to(
312
- torch.device('cpu') if offload_model else self.device)
 
 
313
  if offload_model:
314
  torch.cuda.empty_cache()
315
  noise_pred = noise_pred_uncond + guide_scale * (
316
  noise_pred_cond - noise_pred_uncond)
 
317
 
318
  latent = latent.to(
319
  torch.device('cpu') if offload_model else self.device)
@@ -325,9 +345,14 @@ class WanI2V:
325
  return_dict=False,
326
  generator=seed_g)[0]
327
  latent = temp_x0.squeeze(0)
 
 
 
 
 
 
328
 
329
- x0 = [latent.to(self.device)]
330
- del latent_model_input, timestep
331
 
332
  if offload_model:
333
  self.model.cpu()
 
39
  use_usp=False,
40
  t5_cpu=False,
41
  init_on_cpu=True,
42
+ i2v720p= True,
43
+ model_filename ="",
44
+ text_encoder_filename="",
45
  ):
46
  r"""
47
  Initializes the image-to-video generation model components.
 
80
  text_len=config.text_len,
81
  dtype=config.t5_dtype,
82
  device=torch.device('cpu'),
83
+ checkpoint_path=text_encoder_filename,
84
  tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
85
  shard_fn=shard_fn if t5_fsdp else None,
86
  )
 
98
  config.clip_checkpoint),
99
  tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
100
 
101
+ logging.info(f"Creating WanModel from {model_filename}")
102
+ from mmgp import offload
103
+
104
+ self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel)
105
  self.model.eval().requires_grad_(False)
106
 
107
  if t5_fsdp or dit_fsdp or use_usp:
 
121
  else:
122
  self.sp_size = 1
123
 
124
+ # if dist.is_initialized():
125
+ # dist.barrier()
126
+ # if dit_fsdp:
127
+ # self.model = shard_fn(self.model)
128
+ # else:
129
+ # if not init_on_cpu:
130
+ # self.model.to(self.device)
131
 
132
  self.sample_neg_prompt = config.sample_neg_prompt
133
 
134
  def generate(self,
135
+ input_prompt,
136
+ img,
137
+ max_area=720 * 1280,
138
+ frame_num=81,
139
+ shift=5.0,
140
+ sample_solver='unipc',
141
+ sampling_steps=40,
142
+ guide_scale=5.0,
143
+ n_prompt="",
144
+ seed=-1,
145
+ offload_model=True,
146
+ callback = None
147
+ ):
148
  r"""
149
  Generates video frames from input image and text prompt using diffusion process.
150
 
 
204
  seed_g.manual_seed(seed)
205
  noise = torch.randn(
206
  16,
207
+ int((frame_num - 1)/4 + 1), #21,
208
  lat_h,
209
  lat_w,
210
  dtype=torch.float32,
211
  generator=seed_g,
212
  device=self.device)
213
 
214
+ msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device)
215
  msk[:, 1:] = 0
216
  msk = torch.concat([
217
  torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
 
225
 
226
  # preprocess
227
  if not self.t5_cpu:
228
+ # self.text_encoder.model.to(self.device)
229
  context = self.text_encoder([input_prompt], self.device)
230
  context_null = self.text_encoder([n_prompt], self.device)
231
  if offload_model:
 
236
  context = [t.to(self.device) for t in context]
237
  context_null = [t.to(self.device) for t in context_null]
238
 
239
+ # self.clip.model.to(self.device)
240
  clip_context = self.clip.visual([img[:, None, :, :]])
241
  if offload_model:
242
  self.clip.model.cpu()
243
 
244
+ from mmgp import offload
245
+
246
+ offload.last_offload_obj.unload_all()
247
+ enc= torch.concat([
248
+ torch.nn.functional.interpolate(
249
+ img[None].cpu(), size=(h, w), mode='bicubic').transpose(
250
+ 0, 1).to(torch.bfloat16),
251
+ torch.zeros(3, frame_num-1, h, w, device="cpu", dtype= torch.bfloat16)
252
+ ], dim=1).to(self.device)
253
+ # enc = None
254
+
255
+ y = self.vae.encode([enc])[0]
256
  y = torch.concat([msk, y])
257
 
258
  @contextmanager
 
293
  'clip_fea': clip_context,
294
  'seq_len': max_seq_len,
295
  'y': [y],
296
+ 'pipeline' : self
297
  }
298
 
299
  arg_null = {
 
301
  'clip_fea': clip_context,
302
  'seq_len': max_seq_len,
303
  'y': [y],
304
+ 'pipeline' : self
305
  }
306
 
307
  if offload_model:
308
  torch.cuda.empty_cache()
309
 
310
+ # self.model.to(self.device)
311
+ if callback != None:
312
+ callback(-1, None)
313
+
314
+ self._interrupt = False
315
+ for i, t in enumerate(tqdm(timesteps)):
316
  latent_model_input = [latent.to(self.device)]
317
  timestep = [t]
318
 
319
  timestep = torch.stack(timestep).to(self.device)
320
 
321
  noise_pred_cond = self.model(
322
+ latent_model_input, t=timestep, **arg_c)[0]
323
+ if self._interrupt:
324
+ return None
325
  if offload_model:
326
  torch.cuda.empty_cache()
327
  noise_pred_uncond = self.model(
328
+ latent_model_input, t=timestep, **arg_null)[0]
329
+ if self._interrupt:
330
+ return None
331
+ del latent_model_input
332
  if offload_model:
333
  torch.cuda.empty_cache()
334
  noise_pred = noise_pred_uncond + guide_scale * (
335
  noise_pred_cond - noise_pred_uncond)
336
+ del noise_pred_uncond
337
 
338
  latent = latent.to(
339
  torch.device('cpu') if offload_model else self.device)
 
345
  return_dict=False,
346
  generator=seed_g)[0]
347
  latent = temp_x0.squeeze(0)
348
+ del temp_x0
349
+ del timestep
350
+
351
+ if callback is not None:
352
+ callback(i, latent)
353
+
354
 
355
+ x0 = [latent.to(self.device)]
 
356
 
357
  if offload_model:
358
  self.model.cpu()
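The rewritten denoising loop above folds classifier-free guidance, the new `_interrupt` flag and the progress `callback` into a single pass. A minimal sketch of that loop shape, where `model`, `scheduler`, `callback` and `should_stop` are generic placeholders rather than the actual Wan classes:

```python
import torch

def denoise_with_cfg(model, scheduler, latent, timesteps, arg_c, arg_null,
                     guide_scale=5.0, callback=None, should_stop=lambda: False):
    if callback is not None:
        callback(-1, None)                 # signal that sampling is about to start
    for i, t in enumerate(timesteps):
        if should_stop():                  # cooperative interrupt, like self._interrupt
            return None
        timestep = torch.stack([t])
        cond = model([latent], t=timestep, **arg_c)[0]
        uncond = model([latent], t=timestep, **arg_null)[0]
        # classifier-free guidance: push the prediction away from the unconditional branch
        noise_pred = uncond + guide_scale * (cond - uncond)
        latent = scheduler.step(noise_pred.unsqueeze(0), t, latent.unsqueeze(0),
                                return_dict=False)[0].squeeze(0)
        del cond, uncond, noise_pred       # drop per-step tensors early, as the diff does
        if callback is not None:
            callback(i, latent)            # progress report with the current latent
    return latent
```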
wan/modules/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- from .attention import flash_attention
2
  from .model import WanModel
3
  from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
4
  from .tokenizers import HuggingfaceTokenizer
@@ -12,5 +12,5 @@ __all__ = [
12
  'T5Decoder',
13
  'T5EncoderModel',
14
  'HuggingfaceTokenizer',
15
- 'flash_attention',
16
  ]
 
1
+ from .attention import pay_attention
2
  from .model import WanModel
3
  from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
4
  from .tokenizers import HuggingfaceTokenizer
 
12
  'T5Decoder',
13
  'T5EncoderModel',
14
  'HuggingfaceTokenizer',
15
+ 'pay_attention',
16
  ]
wan/modules/attention.py CHANGED
@@ -1,5 +1,9 @@
1
  # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
  import torch
 
 
 
 
3
 
4
  try:
5
  import flash_attn_interface
@@ -12,19 +16,99 @@ try:
12
  FLASH_ATTN_2_AVAILABLE = True
13
  except ModuleNotFoundError:
14
  FLASH_ATTN_2_AVAILABLE = False
 
 
15
 
16
  import warnings
17
 
 
 
18
  __all__ = [
19
- 'flash_attention',
20
  'attention',
21
  ]
22
 
23
 
24
- def flash_attention(
25
- q,
26
- k,
27
- v,
 
28
  q_lens=None,
29
  k_lens=None,
30
  dropout_p=0.,
@@ -49,6 +133,10 @@ def flash_attention(
49
  deterministic: bool. If True, slightly slower and uses more memory.
50
  dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
51
  """
 
 
 
 
52
  half_dtypes = (torch.float16, torch.bfloat16)
53
  assert dtype in half_dtypes
54
  assert q.device.type == 'cuda' and q.size(-1) <= 256
@@ -91,7 +179,27 @@ def flash_attention(
91
  )
92
 
93
  # apply attention
94
- if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
 
 
95
  # Note: dropout_p, window_size are not supported in FA3 now.
96
  x = flash_attn_interface.flash_attn_varlen_func(
97
  q=q,
@@ -108,8 +216,7 @@ def flash_attention(
108
  softmax_scale=softmax_scale,
109
  causal=causal,
110
  deterministic=deterministic)[0].unflatten(0, (b, lq))
111
- else:
112
- assert FLASH_ATTN_2_AVAILABLE
113
  x = flash_attn.flash_attn_varlen_func(
114
  q=q,
115
  k=k,
@@ -146,7 +253,7 @@ def attention(
146
  fa_version=None,
147
  ):
148
  if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
149
- return flash_attention(
150
  q=q,
151
  k=k,
152
  v=v,
 
1
  # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
  import torch
3
+ from importlib.metadata import version
4
+ from mmgp import offload
5
+ import torch.nn.functional as F
6
+
7
 
8
  try:
9
  import flash_attn_interface
 
16
  FLASH_ATTN_2_AVAILABLE = True
17
  except ModuleNotFoundError:
18
  FLASH_ATTN_2_AVAILABLE = False
19
+ flash_attn = None
20
+
21
+ try:
22
+ from sageattention import sageattn_varlen
23
+ def sageattn_varlen_wrapper(
24
+ q,
25
+ k,
26
+ v,
27
+ cu_seqlens_q,
28
+ cu_seqlens_kv,
29
+ max_seqlen_q,
30
+ max_seqlen_kv,
31
+ ):
32
+ return sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
33
+ except ImportError:
34
+ sageattn_varlen_wrapper = None
35
+
36
 
37
  import warnings
38
 
39
+ try:
40
+ from sageattention import sageattn
41
+ @torch.compiler.disable()
42
+ def sageattn_wrapper(
43
+ qkv_list,
44
+ attention_length
45
+ ):
46
+ q,k, v = qkv_list
47
+ padding_length = q.shape[0] -attention_length
48
+ q = q[:attention_length, :, : ].unsqueeze(0)
49
+ k = k[:attention_length, :, : ].unsqueeze(0)
50
+ v = v[:attention_length, :, : ].unsqueeze(0)
51
+
52
+ o = sageattn(q, k, v, tensor_layout="NHD").squeeze(0)
53
+ del q, k ,v
54
+ qkv_list.clear()
55
+
56
+ if padding_length > 0:
57
+ o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
58
+
59
+ return o
60
+ except ImportError:
61
+ sageattn = None
62
+
63
+
64
+ @torch.compiler.disable()
65
+ def sdpa_wrapper(
66
+ qkv_list,
67
+ attention_length
68
+ ):
69
+ q,k, v = qkv_list
70
+ padding_length = q.shape[0] -attention_length
71
+ q = q[:attention_length, :].transpose(0,1).unsqueeze(0)
72
+ k = k[:attention_length, :].transpose(0,1).unsqueeze(0)
73
+ v = v[:attention_length, :].transpose(0,1).unsqueeze(0)
74
+
75
+ o = F.scaled_dot_product_attention(
76
+ q, k, v, attn_mask=None, is_causal=False
77
+ ).squeeze(0).transpose(0,1)
78
+ del q, k ,v
79
+ qkv_list.clear()
80
+
81
+ if padding_length > 0:
82
+ o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
83
+
84
+ return o
85
+
86
+
87
+ def get_attention_modes():
88
+ ret = ["sdpa", "auto"]
89
+ if flash_attn != None:
90
+ ret.append("flash")
91
+ # if memory_efficient_attention != None:
92
+ # ret.append("xformers")
93
+ if sageattn_varlen_wrapper != None:
94
+ ret.append("sage")
95
+ if sageattn != None and version("sageattention").startswith("2") :
96
+ ret.append("sage2")
97
+
98
+ return ret
99
+
100
+
101
  __all__ = [
102
+ 'pay_attention',
103
  'attention',
104
  ]
105
 
106
 
107
+ def pay_attention(
108
+ qkv_list,
109
+ # q,
110
+ # k,
111
+ # v,
112
  q_lens=None,
113
  k_lens=None,
114
  dropout_p=0.,
 
133
  deterministic: bool. If True, slightly slower and uses more memory.
134
  dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
135
  """
136
+ attn = offload.shared_state["_attention"]
137
+ q,k,v = qkv_list
138
+ qkv_list.clear()
139
+
140
  half_dtypes = (torch.float16, torch.bfloat16)
141
  assert dtype in half_dtypes
142
  assert q.device.type == 'cuda' and q.size(-1) <= 256
 
179
  )
180
 
181
  # apply attention
182
+ if attn=="sage":
183
+ x = sageattn_varlen_wrapper(
184
+ q=q,
185
+ k=k,
186
+ v=v,
187
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
188
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
189
+ cu_seqlens_kv=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
190
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
191
+ max_seqlen_q=lq,
192
+ max_seqlen_kv=lk,
193
+ ).unflatten(0, (b, lq))
194
+ elif attn=="sage2":
195
+ qkv_list = [q,k,v]
196
+ del q,k,v
197
+ x = sageattn_wrapper(qkv_list, lq).unsqueeze(0)
198
+ elif attn=="sdpa":
199
+ qkv_list = [q, k, v]
200
+ del q, k , v
201
+ x = sdpa_wrapper( qkv_list, lq).unsqueeze(0)
202
+ elif attn=="flash" and (version is None or version == 3):
203
  # Note: dropout_p, window_size are not supported in FA3 now.
204
  x = flash_attn_interface.flash_attn_varlen_func(
205
  q=q,
 
216
  softmax_scale=softmax_scale,
217
  causal=causal,
218
  deterministic=deterministic)[0].unflatten(0, (b, lq))
219
+ elif attn=="flash":
 
220
  x = flash_attn.flash_attn_varlen_func(
221
  q=q,
222
  k=k,
 
253
  fa_version=None,
254
  ):
255
  if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
256
+ return pay_attention(
257
  q=q,
258
  k=k,
259
  v=v,
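`pay_attention` above selects its kernel at run time from `offload.shared_state["_attention"]` and falls back across SageAttention, FlashAttention and plain PyTorch SDPA. A rough, self-contained illustration of the same probe-and-dispatch idea; the function names below are illustrative stand-ins, not the mmgp API:

```python
import torch
import torch.nn.functional as F

def available_attention_modes():
    # probe optional kernels, in the spirit of get_attention_modes()
    modes = ["sdpa"]
    try:
        import flash_attn  # noqa: F401
        modes.append("flash")
    except ImportError:
        pass
    try:
        import sageattention  # noqa: F401
        modes.append("sage")
    except ImportError:
        pass
    return modes

def run_attention(q, k, v, mode="auto"):
    # q, k, v: [batch, seq, heads, head_dim]; flash/sage expect fp16 or bf16 CUDA tensors
    if mode in ("auto", "sdpa") or mode not in available_attention_modes():
        # SDPA ships with PyTorch, so it is the safe fallback
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=False)
        return out.transpose(1, 2)
    if mode == "flash":
        from flash_attn import flash_attn_func
        return flash_attn_func(q, k, v, causal=False)
    from sageattention import sageattn
    return sageattn(q, k, v, tensor_layout="NHD")
```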
wan/modules/clip.py CHANGED
@@ -8,7 +8,7 @@ import torch.nn as nn
8
  import torch.nn.functional as F
9
  import torchvision.transforms as T
10
 
11
- from .attention import flash_attention
12
  from .tokenizers import HuggingfaceTokenizer
13
  from .xlm_roberta import XLMRoberta
14
 
@@ -82,7 +82,7 @@ class SelfAttention(nn.Module):
82
 
83
  # compute attention
84
  p = self.attn_dropout if self.training else 0.0
85
- x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
86
  x = x.reshape(b, s, c)
87
 
88
  # output
@@ -194,7 +194,7 @@ class AttentionPool(nn.Module):
194
  k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
195
 
196
  # compute attention
197
- x = flash_attention(q, k, v, version=2)
198
  x = x.reshape(b, 1, c)
199
 
200
  # output
@@ -441,11 +441,12 @@ def _clip(pretrained=False,
441
  device='cpu',
442
  **kwargs):
443
  # init a model on device
 
444
  with torch.device(device):
445
  model = model_cls(**kwargs)
446
 
447
  # set device
448
- model = model.to(dtype=dtype, device=device)
449
  output = (model,)
450
 
451
  # init transforms
@@ -507,16 +508,19 @@ class CLIPModel:
507
  self.tokenizer_path = tokenizer_path
508
 
509
  # init model
510
- self.model, self.transforms = clip_xlm_roberta_vit_h_14(
511
- pretrained=False,
512
- return_transforms=True,
513
- return_tokenizer=False,
514
- dtype=dtype,
515
- device=device)
 
 
 
516
  self.model = self.model.eval().requires_grad_(False)
517
  logging.info(f'loading {checkpoint_path}')
518
  self.model.load_state_dict(
519
- torch.load(checkpoint_path, map_location='cpu'))
520
 
521
  # init tokenizer
522
  self.tokenizer = HuggingfaceTokenizer(
 
8
  import torch.nn.functional as F
9
  import torchvision.transforms as T
10
 
11
+ from .attention import pay_attention
12
  from .tokenizers import HuggingfaceTokenizer
13
  from .xlm_roberta import XLMRoberta
14
 
 
82
 
83
  # compute attention
84
  p = self.attn_dropout if self.training else 0.0
85
+ x = pay_attention([q, k, v], dropout_p=p, causal=self.causal, version=2)
86
  x = x.reshape(b, s, c)
87
 
88
  # output
 
194
  k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
195
 
196
  # compute attention
197
+ x = pay_attention([q, k, v], version=2)
198
  x = x.reshape(b, 1, c)
199
 
200
  # output
 
441
  device='cpu',
442
  **kwargs):
443
  # init a model on device
444
+ device ="cpu"
445
  with torch.device(device):
446
  model = model_cls(**kwargs)
447
 
448
  # set device
449
+ # model = model.to(dtype=dtype, device=device)
450
  output = (model,)
451
 
452
  # init transforms
 
508
  self.tokenizer_path = tokenizer_path
509
 
510
  # init model
511
+ from accelerate import init_empty_weights
512
+
513
+ with init_empty_weights():
514
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
515
+ pretrained=False,
516
+ return_transforms=True,
517
+ return_tokenizer=False,
518
+ dtype=dtype,
519
+ device=device)
520
  self.model = self.model.eval().requires_grad_(False)
521
  logging.info(f'loading {checkpoint_path}')
522
  self.model.load_state_dict(
523
+ torch.load(checkpoint_path, map_location='cpu'), assign=True)
524
 
525
  # init tokenizer
526
  self.tokenizer = HuggingfaceTokenizer(
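Both the CLIP and the T5 loaders now build the module on the meta device first and only materialize real tensors when the checkpoint is read, so the weights are never allocated twice. A small stand-alone sketch of that pattern, with a toy `nn.Sequential` in place of the real models and a hypothetical checkpoint path:

```python
import torch
import torch.nn as nn
from accelerate import init_empty_weights

def load_without_double_allocation(checkpoint_path: str) -> nn.Module:
    # 1) build the architecture with meta tensors: no CPU/GPU memory is committed yet
    with init_empty_weights():
        model = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096))
    # 2) read the checkpoint and attach its tensors in place of the empty parameters;
    #    assign=True swaps tensors in instead of copying into meta storage (which would fail)
    state = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state, assign=True)
    return model.eval().requires_grad_(False)
```

Note that `assign=True` requires PyTorch 2.1 or newer.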
wan/modules/model.py CHANGED
@@ -7,7 +7,7 @@ import torch.nn as nn
7
  from diffusers.configuration_utils import ConfigMixin, register_to_config
8
  from diffusers.models.modeling_utils import ModelMixin
9
 
10
- from .attention import flash_attention
11
 
12
  __all__ = ['WanModel']
13
 
@@ -16,7 +16,7 @@ def sinusoidal_embedding_1d(dim, position):
16
  # preprocess
17
  assert dim % 2 == 0
18
  half = dim // 2
19
- position = position.type(torch.float64)
20
 
21
  # calculation
22
  sinusoid = torch.outer(
@@ -25,18 +25,47 @@ def sinusoidal_embedding_1d(dim, position):
25
  return x
26
 
27
 
28
- @amp.autocast(enabled=False)
29
  def rope_params(max_seq_len, dim, theta=10000):
30
  assert dim % 2 == 0
31
  freqs = torch.outer(
32
  torch.arange(max_seq_len),
33
  1.0 / torch.pow(theta,
34
- torch.arange(0, dim, 2).to(torch.float64).div(dim)))
35
  freqs = torch.polar(torch.ones_like(freqs), freqs)
36
  return freqs
37
 
38
 
39
- @amp.autocast(enabled=False)
 
 
40
  def rope_apply(x, grid_sizes, freqs):
41
  n, c = x.size(2), x.size(3) // 2
42
 
@@ -45,12 +74,17 @@ def rope_apply(x, grid_sizes, freqs):
45
 
46
  # loop over samples
47
  output = []
48
- for i, (f, h, w) in enumerate(grid_sizes.tolist()):
49
  seq_len = f * h * w
50
 
51
  # precompute multipliers
52
- x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
53
- seq_len, n, -1, 2))
 
 
 
 
 
54
  freqs_i = torch.cat([
55
  freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
56
  freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
@@ -59,12 +93,14 @@ def rope_apply(x, grid_sizes, freqs):
59
  dim=-1).reshape(seq_len, 1, -1)
60
 
61
  # apply rotary embedding
62
- x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
 
 
63
  x_i = torch.cat([x_i, x[i, seq_len:]])
64
 
65
  # append to collection
66
  output.append(x_i)
67
- return torch.stack(output).float()
68
 
69
 
70
  class WanRMSNorm(nn.Module):
@@ -80,11 +116,31 @@ class WanRMSNorm(nn.Module):
80
  Args:
81
  x(Tensor): Shape [B, L, C]
82
  """
83
- return self._norm(x.float()).type_as(x) * self.weight
 
 
84
 
85
  def _norm(self, x):
86
  return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
87
 
 
 
88
 
89
  class WanLayerNorm(nn.LayerNorm):
90
 
@@ -96,7 +152,13 @@ class WanLayerNorm(nn.LayerNorm):
96
  Args:
97
  x(Tensor): Shape [B, L, C]
98
  """
99
- return super().forward(x.float()).type_as(x)
 
 
 
 
 
 
100
 
101
 
102
  class WanSelfAttention(nn.Module):
@@ -124,7 +186,7 @@ class WanSelfAttention(nn.Module):
124
  self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
125
  self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
126
 
127
- def forward(self, x, seq_lens, grid_sizes, freqs):
128
  r"""
129
  Args:
130
  x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -132,24 +194,31 @@ class WanSelfAttention(nn.Module):
132
  grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
133
  freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
134
  """
 
 
 
135
  b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
136
 
137
  # query, key, value function
138
- def qkv_fn(x):
139
- q = self.norm_q(self.q(x)).view(b, s, n, d)
140
- k = self.norm_k(self.k(x)).view(b, s, n, d)
141
- v = self.v(x).view(b, s, n, d)
142
- return q, k, v
143
-
144
- q, k, v = qkv_fn(x)
145
-
146
- x = flash_attention(
147
- q=rope_apply(q, grid_sizes, freqs),
148
- k=rope_apply(k, grid_sizes, freqs),
149
- v=v,
150
- k_lens=seq_lens,
 
 
 
 
 
151
  window_size=self.window_size)
152
-
153
  # output
154
  x = x.flatten(2)
155
  x = self.o(x)
@@ -158,22 +227,31 @@ class WanSelfAttention(nn.Module):
158
 
159
  class WanT2VCrossAttention(WanSelfAttention):
160
 
161
- def forward(self, x, context, context_lens):
162
  r"""
163
  Args:
164
  x(Tensor): Shape [B, L1, C]
165
  context(Tensor): Shape [B, L2, C]
166
  context_lens(Tensor): Shape [B]
167
  """
 
 
168
  b, n, d = x.size(0), self.num_heads, self.head_dim
169
 
170
  # compute query, key, value
171
- q = self.norm_q(self.q(x)).view(b, -1, n, d)
172
- k = self.norm_k(self.k(context)).view(b, -1, n, d)
 
 
 
 
 
173
  v = self.v(context).view(b, -1, n, d)
174
 
175
  # compute attention
176
- x = flash_attention(q, k, v, k_lens=context_lens)
 
 
177
 
178
  # output
179
  x = x.flatten(2)
@@ -196,31 +274,54 @@ class WanI2VCrossAttention(WanSelfAttention):
196
  # self.alpha = nn.Parameter(torch.zeros((1, )))
197
  self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
198
 
199
- def forward(self, x, context, context_lens):
200
  r"""
201
  Args:
202
  x(Tensor): Shape [B, L1, C]
203
  context(Tensor): Shape [B, L2, C]
204
  context_lens(Tensor): Shape [B]
205
  """
 
 
206
  context_img = context[:, :257]
207
  context = context[:, 257:]
208
  b, n, d = x.size(0), self.num_heads, self.head_dim
209
 
210
  # compute query, key, value
211
- q = self.norm_q(self.q(x)).view(b, -1, n, d)
212
- k = self.norm_k(self.k(context)).view(b, -1, n, d)
 
 
 
 
 
213
  v = self.v(context).view(b, -1, n, d)
214
- k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
 
 
 
 
 
 
 
215
  v_img = self.v_img(context_img).view(b, -1, n, d)
216
- img_x = flash_attention(q, k_img, v_img, k_lens=None)
 
 
217
  # compute attention
218
- x = flash_attention(q, k, v, k_lens=context_lens)
219
 
220
  # output
221
  x = x.flatten(2)
222
  img_x = img_x.flatten(2)
223
- x = x + img_x
 
224
  x = self.o(x)
225
  return x
226
 
@@ -289,27 +390,46 @@ class WanAttentionBlock(nn.Module):
289
  grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
290
  freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
291
  """
292
- assert e.dtype == torch.float32
293
- with amp.autocast(dtype=torch.float32):
294
- e = (self.modulation + e).chunk(6, dim=1)
295
- assert e[0].dtype == torch.float32
296
-
297
  # self-attention
298
- y = self.self_attn(
299
- self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
300
- freqs)
301
- with amp.autocast(dtype=torch.float32):
302
- x = x + y * e[2]
303
-
304
- # cross-attention & ffn function
305
- def cross_attn_ffn(x, context, context_lens, e):
306
- x = x + self.cross_attn(self.norm3(x), context, context_lens)
307
- y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
308
- with amp.autocast(dtype=torch.float32):
309
- x = x + y * e[5]
310
- return x
311
-
312
- x = cross_attn_ffn(x, context, context_lens, e)
 
 
313
  return x
314
 
315
 
@@ -336,10 +456,13 @@ class Head(nn.Module):
336
  x(Tensor): Shape [B, L1, C]
337
  e(Tensor): Shape [B, C]
338
  """
339
- assert e.dtype == torch.float32
340
- with amp.autocast(dtype=torch.float32):
341
- e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
342
- x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
 
 
 
343
  return x
344
 
345
 
@@ -384,7 +507,8 @@ class WanModel(ModelMixin, ConfigMixin):
384
  window_size=(-1, -1),
385
  qk_norm=True,
386
  cross_attn_norm=True,
387
- eps=1e-6):
 
388
  r"""
389
  Initialize the diffusion model backbone.
390
 
@@ -466,7 +590,7 @@ class WanModel(ModelMixin, ConfigMixin):
466
  # buffers (don't use register_buffer otherwise dtype will be changed in to())
467
  assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
468
  d = dim // num_heads
469
- self.freqs = torch.cat([
470
  rope_params(1024, d - 4 * (d // 6)),
471
  rope_params(1024, 2 * (d // 6)),
472
  rope_params(1024, 2 * (d // 6))
@@ -487,6 +611,7 @@ class WanModel(ModelMixin, ConfigMixin):
487
  seq_len,
488
  clip_fea=None,
489
  y=None,
 
490
  ):
491
  r"""
492
  Forward pass through the diffusion model
@@ -521,8 +646,11 @@ class WanModel(ModelMixin, ConfigMixin):
521
 
522
  # embeddings
523
  x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
524
- grid_sizes = torch.stack(
525
- [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
 
 
 
526
  x = [u.flatten(2).transpose(1, 2) for u in x]
527
  seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
528
  assert seq_lens.max() <= seq_len
@@ -532,11 +660,10 @@ class WanModel(ModelMixin, ConfigMixin):
532
  ])
533
 
534
  # time embeddings
535
- with amp.autocast(dtype=torch.float32):
536
- e = self.time_embedding(
537
- sinusoidal_embedding_1d(self.freq_dim, t).float())
538
- e0 = self.time_projection(e).unflatten(1, (6, self.dim))
539
- assert e.dtype == torch.float32 and e0.dtype == torch.float32
540
 
541
  # context
542
  context_lens = None
@@ -561,6 +688,9 @@ class WanModel(ModelMixin, ConfigMixin):
561
  context_lens=context_lens)
562
 
563
  for block in self.blocks:
 
 
 
564
  x = block(x, **kwargs)
565
 
566
  # head
@@ -588,7 +718,7 @@ class WanModel(ModelMixin, ConfigMixin):
588
 
589
  c = self.out_dim
590
  out = []
591
- for u, v in zip(x, grid_sizes.tolist()):
592
  u = u[:math.prod(v)].view(*v, *self.patch_size, c)
593
  u = torch.einsum('fhwpqrc->cfphqwr', u)
594
  u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
 
7
  from diffusers.configuration_utils import ConfigMixin, register_to_config
8
  from diffusers.models.modeling_utils import ModelMixin
9
 
10
+ from .attention import pay_attention
11
 
12
  __all__ = ['WanModel']
13
 
 
16
  # preprocess
17
  assert dim % 2 == 0
18
  half = dim // 2
19
+ position = position.type(torch.float32)
20
 
21
  # calculation
22
  sinusoid = torch.outer(
 
25
  return x
26
 
27
 
28
+ # @amp.autocast(enabled=False)
29
  def rope_params(max_seq_len, dim, theta=10000):
30
  assert dim % 2 == 0
31
  freqs = torch.outer(
32
  torch.arange(max_seq_len),
33
  1.0 / torch.pow(theta,
34
+ torch.arange(0, dim, 2).to(torch.float32).div(dim)))
35
  freqs = torch.polar(torch.ones_like(freqs), freqs)
36
  return freqs
37
 
38
 
39
+ def rope_apply_(x, grid_sizes, freqs):
40
+ assert x.shape[0]==1
41
+
42
+ n, c = x.size(2), x.size(3) // 2
43
+
44
+ # split freqs
45
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
46
+
47
+ f, h, w = grid_sizes[0]
48
+ seq_len = f * h * w
49
+ x_i = x[0, :seq_len, :, :]
50
+
51
+ x_i = x_i.to(torch.float32)
52
+ x_i = x_i.reshape(seq_len, n, -1, 2)
53
+ x_i = torch.view_as_complex(x_i)
54
+ freqs_i = torch.cat([
55
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
56
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
57
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
58
+ ], dim=-1)
59
+ freqs_i= freqs_i.reshape(seq_len, 1, -1)
60
+
61
+ # apply rotary embedding
62
+ x_i *= freqs_i
63
+ x_i = torch.view_as_real(x_i).flatten(2)
64
+ x[0, :seq_len, :, :] = x_i.to(torch.bfloat16)
65
+ # x_i = torch.cat([x_i, x[0, seq_len:]])
66
+ return x
67
+
68
+ # @amp.autocast(enabled=False)
69
  def rope_apply(x, grid_sizes, freqs):
70
  n, c = x.size(2), x.size(3) // 2
71
 
 
74
 
75
  # loop over samples
76
  output = []
77
+ for i, (f, h, w) in enumerate(grid_sizes):
78
  seq_len = f * h * w
79
 
80
  # precompute multipliers
81
+ # x_i = x[i, :seq_len]
82
+ x_i = x[i]
83
+ x_i = x_i[:seq_len, :, :]
84
+
85
+ x_i = x_i.to(torch.float32)
86
+ x_i = x_i.reshape(seq_len, n, -1, 2)
87
+ x_i = torch.view_as_complex(x_i)
88
  freqs_i = torch.cat([
89
  freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
90
  freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
 
93
  dim=-1).reshape(seq_len, 1, -1)
94
 
95
  # apply rotary embedding
96
+ x_i *= freqs_i
97
+ x_i = torch.view_as_real(x_i).flatten(2)
98
+ x_i = x_i.to(torch.bfloat16)
99
  x_i = torch.cat([x_i, x[i, seq_len:]])
100
 
101
  # append to collection
102
  output.append(x_i)
103
+ return torch.stack(output) #.float()
104
 
105
 
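`rope_apply_` above rotates query/key vectors in place by viewing consecutive channel pairs as complex numbers and multiplying them with precomputed unit-magnitude frequencies. A compact one-axis illustration of the same mechanism (a single position dimension instead of the model's (F, H, W) split; the function name is illustrative):

```python
import torch

def rope_rotate_1d(x: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    # x: [seq_len, heads, dim] with an even dim; the rotation acts on consecutive pairs
    seq_len, _, dim = x.shape
    inv_freq = 1.0 / theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)
    freqs = torch.polar(torch.ones_like(angles), angles)        # e^(i*angle), shape [seq, dim/2]
    x_c = torch.view_as_complex(x.float().reshape(seq_len, -1, dim // 2, 2))
    x_c = x_c * freqs.unsqueeze(1)                              # rotate each channel pair
    return torch.view_as_real(x_c).flatten(2).to(x.dtype)

q = torch.randn(16, 12, 128, dtype=torch.bfloat16)
q_rot = rope_rotate_1d(q)   # same shape, position information mixed into the phases
```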
106
  class WanRMSNorm(nn.Module):
 
116
  Args:
117
  x(Tensor): Shape [B, L, C]
118
  """
119
+ y = x.float()
120
+ y.pow_(2)
121
+ y = y.mean(dim=-1, keepdim=True)
122
+ y += self.eps
123
+ y.rsqrt_()
124
+ x *= y
125
+ x *= self.weight
126
+ return x
127
+ # return self._norm(x).type_as(x) * self.weight
128
 
129
  def _norm(self, x):
130
  return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
131
 
132
+ def my_LayerNorm(norm, x):
133
+ y = x.float()
134
+ y_m = y.mean(dim=-1, keepdim=True)
135
+ y -= y_m
136
+ del y_m
137
+ y.pow_(2)
138
+ y = y.mean(dim=-1, keepdim=True)
139
+ y += norm.eps
140
+ y.rsqrt_()
141
+ x = x * y
142
+ return x
143
+
144
 
145
  class WanLayerNorm(nn.LayerNorm):
146
 
 
152
  Args:
153
  x(Tensor): Shape [B, L, C]
154
  """
155
+ # return F.layer_norm(
156
+ # input, self.normalized_shape, self.weight, self.bias, self.eps
157
+ # )
158
+ y = super().forward(x)
159
+ x = y.type_as(x)
160
+ return x
161
+ # return super().forward(x).type_as(x)
162
 
163
 
164
  class WanSelfAttention(nn.Module):
 
186
  self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
187
  self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
188
 
189
+ def forward(self, xlist, seq_lens, grid_sizes, freqs):
190
  r"""
191
  Args:
192
  x(Tensor): Shape [B, L, num_heads, C / num_heads]
 
194
  grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
195
  freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
196
  """
197
+ x = xlist[0]
198
+ xlist.clear()
199
+
200
  b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
201
 
202
  # query, key, value function
203
+ q = self.q(x)
204
+ self.norm_q(q)
205
+ q = q.view(b, s, n, d) # !!!
206
+ k = self.k(x)
207
+ self.norm_k(k)
208
+ k = k.view(b, s, n, d)
209
+ v = self.v(x).view(b, s, n, d)
210
+ del x
211
+ rope_apply_(q, grid_sizes, freqs)
212
+ rope_apply_(k, grid_sizes, freqs)
213
+ qkv_list = [q,k,v]
214
+ del q,k,v
215
+ x = pay_attention(
216
+ qkv_list,
217
+ # q=q,
218
+ # k=k,
219
+ # v=v,
220
+ # k_lens=seq_lens,
221
  window_size=self.window_size)
 
222
  # output
223
  x = x.flatten(2)
224
  x = self.o(x)
 
227
 
228
  class WanT2VCrossAttention(WanSelfAttention):
229
 
230
+ def forward(self, xlist, context, context_lens):
231
  r"""
232
  Args:
233
  x(Tensor): Shape [B, L1, C]
234
  context(Tensor): Shape [B, L2, C]
235
  context_lens(Tensor): Shape [B]
236
  """
237
+ x = xlist[0]
238
+ xlist.clear()
239
  b, n, d = x.size(0), self.num_heads, self.head_dim
240
 
241
  # compute query, key, value
242
+ q = self.q(x)
243
+ del x
244
+ self.norm_q(q)
245
+ q= q.view(b, -1, n, d)
246
+ k = self.k(context)
247
+ self.norm_k(k)
248
+ k = k.view(b, -1, n, d)
249
  v = self.v(context).view(b, -1, n, d)
250
 
251
  # compute attention
252
+ qkv_list = [q, k, v]
253
+ del q, k, v
254
+ x = pay_attention(qkv_list, k_lens=context_lens)
255
 
256
  # output
257
  x = x.flatten(2)
 
274
  # self.alpha = nn.Parameter(torch.zeros((1, )))
275
  self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
276
 
277
+ def forward(self, xlist, context, context_lens):
278
  r"""
279
  Args:
280
  x(Tensor): Shape [B, L1, C]
281
  context(Tensor): Shape [B, L2, C]
282
  context_lens(Tensor): Shape [B]
283
  """
284
+
285
+ ##### Enjoy these spaghetti VRAM optimizations by DeepBeepMeep!
286
+ # I am sure you are a nice person and, as you copy this code, you will give me proper credit:
287
+ # Please link to https://github.com/deepbeepmeep/Wan2GP and @deepbeepmeep on twitter
288
+
289
+ x = xlist[0]
290
+ xlist.clear()
291
+
292
  context_img = context[:, :257]
293
  context = context[:, 257:]
294
  b, n, d = x.size(0), self.num_heads, self.head_dim
295
 
296
  # compute query, key, value
297
+ q = self.q(x)
298
+ del x
299
+ self.norm_q(q)
300
+ q= q.view(b, -1, n, d)
301
+ k = self.k(context)
302
+ self.norm_k(k)
303
+ k = k.view(b, -1, n, d)
304
  v = self.v(context).view(b, -1, n, d)
305
+
306
+ qkv_list = [q, k, v]
307
+ del k,v
308
+ x = pay_attention(qkv_list, k_lens=context_lens)
309
+
310
+ k_img = self.k_img(context_img)
311
+ self.norm_k_img(k_img)
312
+ k_img = k_img.view(b, -1, n, d)
313
  v_img = self.v_img(context_img).view(b, -1, n, d)
314
+ qkv_list = [q, k_img, v_img]
315
+ del q, k_img, v_img
316
+ img_x = pay_attention(qkv_list, k_lens=None)
317
  # compute attention
318
+
319
 
320
  # output
321
  x = x.flatten(2)
322
  img_x = img_x.flatten(2)
323
+ x += img_x
324
+ del img_x
325
  x = self.o(x)
326
  return x
327
 
 
390
  grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
391
  freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
392
  """
393
+ e = (self.modulation + e).chunk(6, dim=1)
394
+
 
 
 
395
  # self-attention
396
+ x_mod = self.norm1(x)
397
+ x_mod *= 1 + e[1]
398
+ x_mod += e[0]
399
+ xlist = [x_mod]
400
+ del x_mod
401
+ y = self.self_attn( xlist, seq_lens, grid_sizes,freqs)
402
+ x.addcmul_(y, e[2])
403
+ del y
404
+ y = self.norm3(x)
405
+ ylist= [y]
406
+ del y
407
+ x += self.cross_attn(ylist, context, context_lens)
408
+ y = self.norm2(x)
409
+
410
+ y *= 1 + e[4]
411
+ y += e[3]
412
+
413
+
414
+ ffn = self.ffn[0]
415
+ gelu = self.ffn[1]
416
+ ffn2= self.ffn[2]
417
+
418
+ y_shape = y.shape
419
+ y = y.view(-1, y_shape[-1])
420
+ chunk_size = int(y_shape[1]/2.7)
421
+ chunks =torch.split(y, chunk_size)
422
+ for y_chunk in chunks:
423
+ mlp_chunk = ffn(y_chunk)
424
+ mlp_chunk = gelu(mlp_chunk)
425
+ y_chunk[...] = ffn2(mlp_chunk)
426
+ del mlp_chunk
427
+ y = y.view(y_shape)
428
+
429
+ x.addcmul_(y, e[5])
430
+
431
+
432
+
433
  return x
434
 
435
 
 
456
  x(Tensor): Shape [B, L1, C]
457
  e(Tensor): Shape [B, C]
458
  """
459
+ # assert e.dtype == torch.float32
460
+
461
+ e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
462
+ x = self.norm(x).to(torch.bfloat16)
463
+ x *= (1 + e[1])
464
+ x += e[0]
465
+ x = self.head(x)
466
  return x
467
 
468
 
 
507
  window_size=(-1, -1),
508
  qk_norm=True,
509
  cross_attn_norm=True,
510
+ eps=1e-6,
511
+ ):
512
  r"""
513
  Initialize the diffusion model backbone.
514
 
 
590
  # buffers (don't use register_buffer otherwise dtype will be changed in to())
591
  assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
592
  d = dim // num_heads
593
+ self.freqs = torch.cat([
594
  rope_params(1024, d - 4 * (d // 6)),
595
  rope_params(1024, 2 * (d // 6)),
596
  rope_params(1024, 2 * (d // 6))
 
611
  seq_len,
612
  clip_fea=None,
613
  y=None,
614
+ pipeline = None,
615
  ):
616
  r"""
617
  Forward pass through the diffusion model
 
646
 
647
  # embeddings
648
  x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
649
+ # grid_sizes = torch.stack(
650
+ # [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
651
+
652
+ grid_sizes = [ list(u.shape[2:]) for u in x]
653
+
654
  x = [u.flatten(2).transpose(1, 2) for u in x]
655
  seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
656
  assert seq_lens.max() <= seq_len
 
660
  ])
661
 
662
  # time embeddings
663
+ e = self.time_embedding(
664
+ sinusoidal_embedding_1d(self.freq_dim, t))
665
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(torch.bfloat16)
666
+ # assert e.dtype == torch.float32 and e0.dtype == torch.float32
 
667
 
668
  # context
669
  context_lens = None
 
688
  context_lens=context_lens)
689
 
690
  for block in self.blocks:
691
+ if pipeline._interrupt:
692
+ return [None]
693
+
694
  x = block(x, **kwargs)
695
 
696
  # head
 
718
 
719
  c = self.out_dim
720
  out = []
721
+ for u, v in zip(x, grid_sizes):
722
  u = u[:math.prod(v)].view(*v, *self.patch_size, c)
723
  u = torch.einsum('fhwpqrc->cfphqwr', u)
724
  u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
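Among the attention-block changes above, the feed-forward network is now evaluated in slices that are written back into the input buffer, so the full `[tokens, 4*dim]` GELU activation never exists at once. A simplified, inference-only version of that chunking trick, assuming a plain Linear-GELU-Linear FFN and a fixed chunk size in place of the ratio-based split used in the diff:

```python
import torch
import torch.nn as nn

@torch.no_grad()   # inference only: the in-place write below is not autograd-safe
def chunked_ffn(y: torch.Tensor, ffn: nn.Module, chunk_size: int = 2048) -> torch.Tensor:
    # y: [batch, seq, dim]; the FFN keeps the last dimension, so its output can
    # overwrite the input buffer one slice at a time
    flat = y.view(-1, y.shape[-1])
    for chunk in torch.split(flat, chunk_size):
        # only one chunk of the 4*dim hidden activation is alive at any moment
        chunk[...] = ffn(chunk)
    return y

ffn = nn.Sequential(nn.Linear(1536, 6144), nn.GELU(approximate='tanh'), nn.Linear(6144, 1536))
x = torch.randn(1, 8192, 1536)
out = chunked_ffn(x, ffn)   # out aliases x; peak activation drops from seq*6144 to chunk_size*6144
```

Because the FFN acts on each token independently, the chunked evaluation matches the unchunked one up to floating-point kernel differences.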
wan/modules/t5.py CHANGED
@@ -442,7 +442,7 @@ def _t5(name,
442
  model = model_cls(**kwargs)
443
 
444
  # set device
445
- model = model.to(dtype=dtype, device=device)
446
 
447
  # init tokenizer
448
  if return_tokenizer:
@@ -486,20 +486,25 @@ class T5EncoderModel:
486
  self.checkpoint_path = checkpoint_path
487
  self.tokenizer_path = tokenizer_path
488
 
 
489
  # init model
490
- model = umt5_xxl(
491
- encoder_only=True,
492
- return_tokenizer=False,
493
- dtype=dtype,
494
- device=device).eval().requires_grad_(False)
 
495
  logging.info(f'loading {checkpoint_path}')
496
- model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
 
 
497
  self.model = model
498
  if shard_fn is not None:
499
  self.model = shard_fn(self.model, sync_module_states=False)
500
  else:
501
  self.model.to(self.device)
502
  # init tokenizer
 
503
  self.tokenizer = HuggingfaceTokenizer(
504
  name=tokenizer_path, seq_len=text_len, clean='whitespace')
505
 
 
442
  model = model_cls(**kwargs)
443
 
444
  # set device
445
+ # model = model.to(dtype=dtype, device=device)
446
 
447
  # init tokenizer
448
  if return_tokenizer:
 
486
  self.checkpoint_path = checkpoint_path
487
  self.tokenizer_path = tokenizer_path
488
 
489
+ from accelerate import init_empty_weights
490
  # init model
491
+ with init_empty_weights():
492
+ model = umt5_xxl(
493
+ encoder_only=True,
494
+ return_tokenizer=False,
495
+ dtype=dtype,
496
+ device=device).eval().requires_grad_(False)
497
  logging.info(f'loading {checkpoint_path}')
498
+ from mmgp import offload
499
+ offload.load_model_data(model,checkpoint_path )
500
+
501
  self.model = model
502
  if shard_fn is not None:
503
  self.model = shard_fn(self.model, sync_module_states=False)
504
  else:
505
  self.model.to(self.device)
506
  # init tokenizer
507
+ tokenizer_path= "google/umt5-xxl"
508
  self.tokenizer = HuggingfaceTokenizer(
509
  name=tokenizer_path, seq_len=text_len, clean='whitespace')
510
 
wan/modules/vae.py CHANGED
@@ -1,6 +1,6 @@
1
  # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
  import logging
3
-
4
  import torch
5
  import torch.cuda.amp as amp
6
  import torch.nn as nn
@@ -31,9 +31,16 @@ class CausalConv3d(nn.Conv3d):
31
  cache_x = cache_x.to(x.device)
32
  x = torch.cat([cache_x, x], dim=2)
33
  padding[4] -= cache_x.shape[2]
 
34
  x = F.pad(x, padding)
 
35
 
36
- return super().forward(x)
 
 
 
 
 
37
 
38
 
39
  class RMS_norm(nn.Module):
@@ -49,10 +56,11 @@ class RMS_norm(nn.Module):
49
  self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
50
 
51
  def forward(self, x):
52
- return F.normalize(
53
  x, dim=(1 if self.channel_first else
54
  -1)) * self.scale * self.gamma + self.bias
55
-
 
56
 
57
  class Upsample(nn.Upsample):
58
 
@@ -107,11 +115,12 @@ class Resample(nn.Module):
107
  feat_cache[idx] = 'Rep'
108
  feat_idx[0] += 1
109
  else:
110
-
111
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
112
  if cache_x.shape[2] < 2 and feat_cache[
113
  idx] is not None and feat_cache[idx] != 'Rep':
114
  # cache last frame of last two chunk
 
115
  cache_x = torch.cat([
116
  feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
117
  cache_x.device), cache_x
@@ -119,11 +128,14 @@ class Resample(nn.Module):
119
  dim=2)
120
  if cache_x.shape[2] < 2 and feat_cache[
121
  idx] is not None and feat_cache[idx] == 'Rep':
 
122
  cache_x = torch.cat([
123
  torch.zeros_like(cache_x).to(cache_x.device),
124
  cache_x
125
  ],
126
  dim=2)
 
 
127
  if feat_cache[idx] == 'Rep':
128
  x = self.time_conv(x)
129
  else:
@@ -144,7 +156,7 @@ class Resample(nn.Module):
144
  if feat_cache is not None:
145
  idx = feat_idx[0]
146
  if feat_cache[idx] is None:
147
- feat_cache[idx] = x.clone()
148
  feat_idx[0] += 1
149
  else:
150
 
@@ -155,7 +167,7 @@ class Resample(nn.Module):
155
 
156
  x = self.time_conv(
157
  torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
158
- feat_cache[idx] = cache_x
159
  feat_idx[0] += 1
160
  return x
161
 
@@ -212,11 +224,11 @@ class ResidualBlock(nn.Module):
212
  cache_x.device), cache_x
213
  ],
214
  dim=2)
215
- x = layer(x, feat_cache[idx])
216
- feat_cache[idx] = cache_x
217
  feat_idx[0] += 1
218
  else:
219
- x = layer(x)
220
  return x + h
221
 
222
 
@@ -326,12 +338,16 @@ class Encoder3d(nn.Module):
326
  cache_x.device), cache_x
327
  ],
328
  dim=2)
329
- x = self.conv1(x, feat_cache[idx])
330
  feat_cache[idx] = cache_x
 
331
  feat_idx[0] += 1
332
  else:
333
  x = self.conv1(x)
334
 
 
 
 
335
  ## downsamples
336
  for layer in self.downsamples:
337
  if feat_cache is not None:
@@ -339,6 +355,8 @@ class Encoder3d(nn.Module):
339
  else:
340
  x = layer(x)
341
 
 
 
342
  ## middle
343
  for layer in self.middle:
344
  if isinstance(layer, ResidualBlock) and feat_cache is not None:
@@ -346,6 +364,8 @@ class Encoder3d(nn.Module):
346
  else:
347
  x = layer(x)
348
 
 
 
349
  ## head
350
  for layer in self.head:
351
  if isinstance(layer, CausalConv3d) and feat_cache is not None:
@@ -360,9 +380,13 @@ class Encoder3d(nn.Module):
360
  dim=2)
361
  x = layer(x, feat_cache[idx])
362
  feat_cache[idx] = cache_x
 
363
  feat_idx[0] += 1
364
  else:
365
  x = layer(x)
 
 
 
366
  return x
367
 
368
 
@@ -433,10 +457,12 @@ class Decoder3d(nn.Module):
433
  ],
434
  dim=2)
435
  x = self.conv1(x, feat_cache[idx])
436
- feat_cache[idx] = cache_x
 
437
  feat_idx[0] += 1
438
  else:
439
  x = self.conv1(x)
 
440
 
441
  ## middle
442
  for layer in self.middle:
@@ -456,7 +482,7 @@ class Decoder3d(nn.Module):
456
  for layer in self.head:
457
  if isinstance(layer, CausalConv3d) and feat_cache is not None:
458
  idx = feat_idx[0]
459
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
460
  if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
461
  # cache last frame of last two chunk
462
  cache_x = torch.cat([
@@ -465,7 +491,8 @@ class Decoder3d(nn.Module):
465
  ],
466
  dim=2)
467
  x = layer(x, feat_cache[idx])
468
- feat_cache[idx] = cache_x
 
469
  feat_idx[0] += 1
470
  else:
471
  x = layer(x)
@@ -532,6 +559,8 @@ class WanVAE_(nn.Module):
532
  feat_cache=self._enc_feat_map,
533
  feat_idx=self._enc_conv_idx)
534
  out = torch.cat([out, out_], 2)
 
 
535
  mu, log_var = self.conv1(out).chunk(2, dim=1)
536
  if isinstance(scale[0], torch.Tensor):
537
  mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
 
1
  # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
  import logging
3
+ from mmgp import offload
4
  import torch
5
  import torch.cuda.amp as amp
6
  import torch.nn as nn
 
31
  cache_x = cache_x.to(x.device)
32
  x = torch.cat([cache_x, x], dim=2)
33
  padding[4] -= cache_x.shape[2]
34
+ cache_x = None
35
  x = F.pad(x, padding)
36
+ x = super().forward(x)
37
 
38
+ mem_threshold = offload.shared_state.get("_vae_threshold",0)
39
+ vae_config = offload.shared_state.get("_vae",1)
40
+
41
+ if vae_config == 0 and torch.cuda.memory_reserved() > mem_threshold or vae_config == 2:
42
+ torch.cuda.empty_cache()
43
+ return x
44
 
45
 
46
  class RMS_norm(nn.Module):
 
56
  self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
57
 
58
  def forward(self, x):
59
+ x = F.normalize(
60
  x, dim=(1 if self.channel_first else
61
  -1)) * self.scale * self.gamma + self.bias
62
+ x = x.to(torch.bfloat16)
63
+ return x
64
 
65
  class Upsample(nn.Upsample):
66
 
 
115
  feat_cache[idx] = 'Rep'
116
  feat_idx[0] += 1
117
  else:
118
+ clone = True
119
+ cache_x = x[:, :, -CACHE_T:, :, :]#.clone()
120
  if cache_x.shape[2] < 2 and feat_cache[
121
  idx] is not None and feat_cache[idx] != 'Rep':
122
  # cache last frame of last two chunk
123
+ clone = False
124
  cache_x = torch.cat([
125
  feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
126
  cache_x.device), cache_x
 
128
  dim=2)
129
  if cache_x.shape[2] < 2 and feat_cache[
130
  idx] is not None and feat_cache[idx] == 'Rep':
131
+ clone = False
132
  cache_x = torch.cat([
133
  torch.zeros_like(cache_x).to(cache_x.device),
134
  cache_x
135
  ],
136
  dim=2)
137
+ if clone:
138
+ cache_x = cache_x.clone()
139
  if feat_cache[idx] == 'Rep':
140
  x = self.time_conv(x)
141
  else:
 
156
  if feat_cache is not None:
157
  idx = feat_idx[0]
158
  if feat_cache[idx] is None:
159
+ feat_cache[idx] = x #.to("cpu") #x.clone() yyyy
160
  feat_idx[0] += 1
161
  else:
162
 
 
167
 
168
  x = self.time_conv(
169
  torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
170
+ feat_cache[idx] = cache_x#.to("cpu") #yyyyy
171
  feat_idx[0] += 1
172
  return x
173
 
 
224
  cache_x.device), cache_x
225
  ],
226
  dim=2)
227
+ x = layer(x, feat_cache[idx]).to(torch.bfloat16)
228
+ feat_cache[idx] = cache_x#.to("cpu")
229
  feat_idx[0] += 1
230
  else:
231
+ x = layer(x).to(torch.bfloat16)
232
  return x + h
233
 
234
 
 
338
  cache_x.device), cache_x
339
  ],
340
  dim=2)
341
+ x = self.conv1(x, feat_cache[idx]).to(torch.bfloat16)
342
  feat_cache[idx] = cache_x
343
+ del cache_x
344
  feat_idx[0] += 1
345
  else:
346
  x = self.conv1(x)
347
 
348
+
349
+ # torch.cuda.empty_cache()
350
+
351
  ## downsamples
352
  for layer in self.downsamples:
353
  if feat_cache is not None:
 
355
  else:
356
  x = layer(x)
357
 
358
+ # torch.cuda.empty_cache()
359
+
360
  ## middle
361
  for layer in self.middle:
362
  if isinstance(layer, ResidualBlock) and feat_cache is not None:
 
364
  else:
365
  x = layer(x)
366
 
367
+ # torch.cuda.empty_cache()
368
+
369
  ## head
370
  for layer in self.head:
371
  if isinstance(layer, CausalConv3d) and feat_cache is not None:
 
380
  dim=2)
381
  x = layer(x, feat_cache[idx])
382
  feat_cache[idx] = cache_x
383
+ del cache_x
384
  feat_idx[0] += 1
385
  else:
386
  x = layer(x)
387
+
388
+ # torch.cuda.empty_cache()
389
+
390
  return x
391
 
392
 
 
457
  ],
458
  dim=2)
459
  x = self.conv1(x, feat_cache[idx])
460
+ feat_cache[idx] = cache_x#.to("cpu")
461
+ del cache_x
462
  feat_idx[0] += 1
463
  else:
464
  x = self.conv1(x)
465
+ cache_x = None
466
 
467
  ## middle
468
  for layer in self.middle:
 
482
  for layer in self.head:
483
  if isinstance(layer, CausalConv3d) and feat_cache is not None:
484
  idx = feat_idx[0]
485
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
486
  if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
487
  # cache last frame of last two chunk
488
  cache_x = torch.cat([
 
491
  ],
492
  dim=2)
493
  x = layer(x, feat_cache[idx])
494
+ feat_cache[idx] = cache_x#.to("cpu")
495
+ del cache_x
496
  feat_idx[0] += 1
497
  else:
498
  x = layer(x)
 
559
  feat_cache=self._enc_feat_map,
560
  feat_idx=self._enc_conv_idx)
561
  out = torch.cat([out, out_], 2)
562
+
563
+
564
  mu, log_var = self.conv1(out).chunk(2, dim=1)
565
  if isinstance(scale[0], torch.Tensor):
566
  mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
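`CausalConv3d.forward` above now asks the CUDA caching allocator how much memory it is holding and releases it once a configurable threshold is crossed (or after every call in the most aggressive setting). The policy itself is tiny and reusable outside the VAE; a stand-alone sketch, with the default threshold chosen here purely for illustration:

```python
import torch

def maybe_trim_cuda_cache(threshold_bytes: int = 6 * 1024**3, aggressive: bool = False) -> bool:
    """Release cached CUDA blocks once the allocator holds more than `threshold_bytes`.

    Returns True when torch.cuda.empty_cache() was actually called.
    """
    if not torch.cuda.is_available():
        return False
    if aggressive or torch.cuda.memory_reserved() > threshold_bytes:
        torch.cuda.empty_cache()   # hands memory back to the driver; the next allocation is slower
        return True
    return False
```

Calling this after every heavy layer, as the diff does, trades some speed for VRAM headroom during VAE encode and decode.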
wan/text2video.py CHANGED
@@ -35,6 +35,8 @@ class WanT2V:
35
  dit_fsdp=False,
36
  use_usp=False,
37
  t5_cpu=False,
 
 
38
  ):
39
  r"""
40
  Initializes the Wan text-to-video generation model components.
@@ -70,18 +72,26 @@ class WanT2V:
70
  text_len=config.text_len,
71
  dtype=config.t5_dtype,
72
  device=torch.device('cpu'),
73
- checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
74
  tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
75
  shard_fn=shard_fn if t5_fsdp else None)
76
 
77
  self.vae_stride = config.vae_stride
78
  self.patch_size = config.patch_size
 
 
79
  self.vae = WanVAE(
80
  vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
81
  device=self.device)
82
 
83
- logging.info(f"Creating WanModel from {checkpoint_dir}")
84
- self.model = WanModel.from_pretrained(checkpoint_dir)
 
 
 
 
 
 
85
  self.model.eval().requires_grad_(False)
86
 
87
  if use_usp:
@@ -98,12 +108,12 @@ class WanT2V:
98
  else:
99
  self.sp_size = 1
100
 
101
- if dist.is_initialized():
102
- dist.barrier()
103
- if dit_fsdp:
104
- self.model = shard_fn(self.model)
105
- else:
106
- self.model.to(self.device)
107
 
108
  self.sample_neg_prompt = config.sample_neg_prompt
109
 
@@ -117,7 +127,9 @@ class WanT2V:
117
  guide_scale=5.0,
118
  n_prompt="",
119
  seed=-1,
120
- offload_model=True):
 
 
121
  r"""
122
  Generates video frames from text prompt using diffusion process.
123
 
@@ -168,7 +180,7 @@ class WanT2V:
168
  seed_g.manual_seed(seed)
169
 
170
  if not self.t5_cpu:
171
- self.text_encoder.model.to(self.device)
172
  context = self.text_encoder([input_prompt], self.device)
173
  context_null = self.text_encoder([n_prompt], self.device)
174
  if offload_model:
@@ -223,23 +235,32 @@ class WanT2V:
223
  # sample videos
224
  latents = noise
225
 
226
- arg_c = {'context': context, 'seq_len': seq_len}
227
- arg_null = {'context': context_null, 'seq_len': seq_len}
228
 
229
- for _, t in enumerate(tqdm(timesteps)):
 
 
 
230
  latent_model_input = latents
231
  timestep = [t]
232
 
233
  timestep = torch.stack(timestep)
234
 
235
- self.model.to(self.device)
236
  noise_pred_cond = self.model(
237
  latent_model_input, t=timestep, **arg_c)[0]
 
 
238
  noise_pred_uncond = self.model(
239
  latent_model_input, t=timestep, **arg_null)[0]
 
 
240
 
 
241
  noise_pred = noise_pred_uncond + guide_scale * (
242
  noise_pred_cond - noise_pred_uncond)
 
243
 
244
  temp_x0 = sample_scheduler.step(
245
  noise_pred.unsqueeze(0),
@@ -248,6 +269,10 @@ class WanT2V:
248
  return_dict=False,
249
  generator=seed_g)[0]
250
  latents = [temp_x0.squeeze(0)]
 
 
 
 
251
 
252
  x0 = latents
253
  if offload_model:
@@ -256,6 +281,7 @@ class WanT2V:
256
  if self.rank == 0:
257
  videos = self.vae.decode(x0)
258
 
 
259
  del noise, latents
260
  del sample_scheduler
261
  if offload_model:
 
35
  dit_fsdp=False,
36
  use_usp=False,
37
  t5_cpu=False,
38
+ model_filename = None,
39
+ text_encoder_filename = None
40
  ):
41
  r"""
42
  Initializes the Wan text-to-video generation model components.
 
72
  text_len=config.text_len,
73
  dtype=config.t5_dtype,
74
  device=torch.device('cpu'),
75
+ checkpoint_path=text_encoder_filename,
76
  tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
77
  shard_fn=shard_fn if t5_fsdp else None)
78
 
79
  self.vae_stride = config.vae_stride
80
  self.patch_size = config.patch_size
81
+
82
+
83
  self.vae = WanVAE(
84
  vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
85
  device=self.device)
86
 
87
+ logging.info(f"Creating WanModel from {model_filename}")
88
+ from mmgp import offload
89
+
90
+
91
+ self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel)
92
+
93
+
94
+
95
  self.model.eval().requires_grad_(False)
96
 
97
  if use_usp:
 
108
  else:
109
  self.sp_size = 1
110
 
111
+ # if dist.is_initialized():
112
+ # dist.barrier()
113
+ # if dit_fsdp:
114
+ # self.model = shard_fn(self.model)
115
+ # else:
116
+ # self.model.to(self.device)
117
 
118
  self.sample_neg_prompt = config.sample_neg_prompt
119
 
 
127
  guide_scale=5.0,
128
  n_prompt="",
129
  seed=-1,
130
+ offload_model=True,
131
+ callback = None
132
+ ):
133
  r"""
134
  Generates video frames from text prompt using diffusion process.
135
 
 
180
  seed_g.manual_seed(seed)
181
 
182
  if not self.t5_cpu:
183
+ # self.text_encoder.model.to(self.device)
184
  context = self.text_encoder([input_prompt], self.device)
185
  context_null = self.text_encoder([n_prompt], self.device)
186
  if offload_model:
 
235
  # sample videos
236
  latents = noise
237
 
238
+ arg_c = {'context': context, 'seq_len': seq_len, 'pipeline': self}
239
+ arg_null = {'context': context_null, 'seq_len': seq_len, 'pipeline': self}
240
 
241
+ if callback != None:
242
+ callback(-1, None)
243
+ self._interrupt = False
244
+ for i, t in enumerate(tqdm(timesteps)):
245
  latent_model_input = latents
246
  timestep = [t]
247
 
248
  timestep = torch.stack(timestep)
249
 
250
+ # self.model.to(self.device)
251
  noise_pred_cond = self.model(
252
  latent_model_input, t=timestep, **arg_c)[0]
253
+ if self._interrupt:
254
+ return None
255
  noise_pred_uncond = self.model(
256
  latent_model_input, t=timestep, **arg_null)[0]
257
+ if self._interrupt:
258
+ return None
259
 
260
+ del latent_model_input
261
  noise_pred = noise_pred_uncond + guide_scale * (
262
  noise_pred_cond - noise_pred_uncond)
263
+ del noise_pred_uncond
264
 
265
  temp_x0 = sample_scheduler.step(
266
  noise_pred.unsqueeze(0),
 
269
  return_dict=False,
270
  generator=seed_g)[0]
271
  latents = [temp_x0.squeeze(0)]
272
+ del temp_x0
273
+
274
+ if callback is not None:
275
+ callback(i, latents)
276
 
277
  x0 = latents
278
  if offload_model:
 
281
  if self.rank == 0:
282
  videos = self.vae.decode(x0)
283
 
284
+
285
  del noise, latents
286
  del sample_scheduler
287
  if offload_model:
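The text-to-video path gains the same `callback` and `_interrupt` hooks as the image-to-video path above. A hypothetical usage sketch of those hooks; only the `generate(...)` keyword names are taken from the diff, while the construction arguments and file paths are placeholders:

```python
def progress(step, latents):
    # step == -1 is emitted once, just before the first denoising step
    if step < 0:
        print("denoising started")
    else:
        print(f"finished step {step + 1}")

# wan_t2v = WanT2V(config, checkpoint_dir,
#                  model_filename="ckpts/wan2.1_t2v.safetensors",           # placeholder path
#                  text_encoder_filename="ckpts/umt5_xxl_enc.safetensors")  # placeholder path
# video = wan_t2v.generate("a corgi surfing a wave at sunset",
#                          sampling_steps=30, callback=progress)
# Setting wan_t2v._interrupt = True from another thread makes generate() return None.
```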