Spaces:
Running
on
T4
Running
on
T4
DeepBeepMeep
commited on
Commit
·
59880dc
1
Parent(s):
172584e
beta version
Browse files- LICENSE.txt +12 -196
- README.md +95 -314
- gradio/i2v_14B_singleGPU.py +27 -7
- gradio/t2i_14B_singleGPU.py +1 -0
- gradio/t2v_14B_singleGPU.py +23 -12
- gradio_server.py +1275 -0
- loras/README.txt +1 -0
- loras_i2v/README.txt +1 -0
- requirements.txt +4 -1
- wan/image2video.py +67 -42
- wan/modules/__init__.py +2 -2
- wan/modules/attention.py +116 -9
- wan/modules/clip.py +15 -11
- wan/modules/model.py +202 -72
- wan/modules/t5.py +12 -7
- wan/modules/vae.py +44 -15
- wan/text2video.py +41 -15
LICENSE.txt
CHANGED
@@ -1,201 +1,17 @@
|
|
1 |
-
|
2 |
-
Version 2.0, January 2004
|
3 |
-
http://www.apache.org/licenses/
|
4 |
|
5 |
-
|
|
|
|
|
|
|
6 |
|
7 |
-
|
|
|
|
|
8 |
|
9 |
-
|
10 |
-
|
11 |
|
12 |
-
|
13 |
-
the copyright owner that is granting the License.
|
14 |
|
15 |
-
|
16 |
-
other entities that control, are controlled by, or are under common
|
17 |
-
control with that entity. For the purposes of this definition,
|
18 |
-
"control" means (i) the power, direct or indirect, to cause the
|
19 |
-
direction or management of such entity, whether by contract or
|
20 |
-
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
-
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
-
|
23 |
-
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
-
exercising permissions granted by this License.
|
25 |
-
|
26 |
-
"Source" form shall mean the preferred form for making modifications,
|
27 |
-
including but not limited to software source code, documentation
|
28 |
-
source, and configuration files.
|
29 |
-
|
30 |
-
"Object" form shall mean any form resulting from mechanical
|
31 |
-
transformation or translation of a Source form, including but
|
32 |
-
not limited to compiled object code, generated documentation,
|
33 |
-
and conversions to other media types.
|
34 |
-
|
35 |
-
"Work" shall mean the work of authorship, whether in Source or
|
36 |
-
Object form, made available under the License, as indicated by a
|
37 |
-
copyright notice that is included in or attached to the work
|
38 |
-
(an example is provided in the Appendix below).
|
39 |
-
|
40 |
-
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
-
form, that is based on (or derived from) the Work and for which the
|
42 |
-
editorial revisions, annotations, elaborations, or other modifications
|
43 |
-
represent, as a whole, an original work of authorship. For the purposes
|
44 |
-
of this License, Derivative Works shall not include works that remain
|
45 |
-
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
-
the Work and Derivative Works thereof.
|
47 |
-
|
48 |
-
"Contribution" shall mean any work of authorship, including
|
49 |
-
the original version of the Work and any modifications or additions
|
50 |
-
to that Work or Derivative Works thereof, that is intentionally
|
51 |
-
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
-
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
-
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
-
means any form of electronic, verbal, or written communication sent
|
55 |
-
to the Licensor or its representatives, including but not limited to
|
56 |
-
communication on electronic mailing lists, source code control systems,
|
57 |
-
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
-
Licensor for the purpose of discussing and improving the Work, but
|
59 |
-
excluding communication that is conspicuously marked or otherwise
|
60 |
-
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
-
|
62 |
-
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
-
on behalf of whom a Contribution has been received by Licensor and
|
64 |
-
subsequently incorporated within the Work.
|
65 |
-
|
66 |
-
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
-
this License, each Contributor hereby grants to You a perpetual,
|
68 |
-
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
-
copyright license to reproduce, prepare Derivative Works of,
|
70 |
-
publicly display, publicly perform, sublicense, and distribute the
|
71 |
-
Work and such Derivative Works in Source or Object form.
|
72 |
-
|
73 |
-
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
-
this License, each Contributor hereby grants to You a perpetual,
|
75 |
-
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
-
(except as stated in this section) patent license to make, have made,
|
77 |
-
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
-
where such license applies only to those patent claims licensable
|
79 |
-
by such Contributor that are necessarily infringed by their
|
80 |
-
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
-
with the Work to which such Contribution(s) was submitted. If You
|
82 |
-
institute patent litigation against any entity (including a
|
83 |
-
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
-
or a Contribution incorporated within the Work constitutes direct
|
85 |
-
or contributory patent infringement, then any patent licenses
|
86 |
-
granted to You under this License for that Work shall terminate
|
87 |
-
as of the date such litigation is filed.
|
88 |
-
|
89 |
-
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
-
Work or Derivative Works thereof in any medium, with or without
|
91 |
-
modifications, and in Source or Object form, provided that You
|
92 |
-
meet the following conditions:
|
93 |
-
|
94 |
-
(a) You must give any other recipients of the Work or
|
95 |
-
Derivative Works a copy of this License; and
|
96 |
-
|
97 |
-
(b) You must cause any modified files to carry prominent notices
|
98 |
-
stating that You changed the files; and
|
99 |
-
|
100 |
-
(c) You must retain, in the Source form of any Derivative Works
|
101 |
-
that You distribute, all copyright, patent, trademark, and
|
102 |
-
attribution notices from the Source form of the Work,
|
103 |
-
excluding those notices that do not pertain to any part of
|
104 |
-
the Derivative Works; and
|
105 |
-
|
106 |
-
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
-
distribution, then any Derivative Works that You distribute must
|
108 |
-
include a readable copy of the attribution notices contained
|
109 |
-
within such NOTICE file, excluding those notices that do not
|
110 |
-
pertain to any part of the Derivative Works, in at least one
|
111 |
-
of the following places: within a NOTICE text file distributed
|
112 |
-
as part of the Derivative Works; within the Source form or
|
113 |
-
documentation, if provided along with the Derivative Works; or,
|
114 |
-
within a display generated by the Derivative Works, if and
|
115 |
-
wherever such third-party notices normally appear. The contents
|
116 |
-
of the NOTICE file are for informational purposes only and
|
117 |
-
do not modify the License. You may add Your own attribution
|
118 |
-
notices within Derivative Works that You distribute, alongside
|
119 |
-
or as an addendum to the NOTICE text from the Work, provided
|
120 |
-
that such additional attribution notices cannot be construed
|
121 |
-
as modifying the License.
|
122 |
-
|
123 |
-
You may add Your own copyright statement to Your modifications and
|
124 |
-
may provide additional or different license terms and conditions
|
125 |
-
for use, reproduction, or distribution of Your modifications, or
|
126 |
-
for any such Derivative Works as a whole, provided Your use,
|
127 |
-
reproduction, and distribution of the Work otherwise complies with
|
128 |
-
the conditions stated in this License.
|
129 |
-
|
130 |
-
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
-
any Contribution intentionally submitted for inclusion in the Work
|
132 |
-
by You to the Licensor shall be under the terms and conditions of
|
133 |
-
this License, without any additional terms or conditions.
|
134 |
-
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
-
the terms of any separate license agreement you may have executed
|
136 |
-
with Licensor regarding such Contributions.
|
137 |
-
|
138 |
-
6. Trademarks. This License does not grant permission to use the trade
|
139 |
-
names, trademarks, service marks, or product names of the Licensor,
|
140 |
-
except as required for reasonable and customary use in describing the
|
141 |
-
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
-
|
143 |
-
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
-
agreed to in writing, Licensor provides the Work (and each
|
145 |
-
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
-
implied, including, without limitation, any warranties or conditions
|
148 |
-
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
-
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
-
appropriateness of using or redistributing the Work and assume any
|
151 |
-
risks associated with Your exercise of permissions under this License.
|
152 |
-
|
153 |
-
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
-
whether in tort (including negligence), contract, or otherwise,
|
155 |
-
unless required by applicable law (such as deliberate and grossly
|
156 |
-
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
-
liable to You for damages, including any direct, indirect, special,
|
158 |
-
incidental, or consequential damages of any character arising as a
|
159 |
-
result of this License or out of the use or inability to use the
|
160 |
-
Work (including but not limited to damages for loss of goodwill,
|
161 |
-
work stoppage, computer failure or malfunction, or any and all
|
162 |
-
other commercial damages or losses), even if such Contributor
|
163 |
-
has been advised of the possibility of such damages.
|
164 |
-
|
165 |
-
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
-
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
-
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
-
or other liability obligations and/or rights consistent with this
|
169 |
-
License. However, in accepting such obligations, You may act only
|
170 |
-
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
-
of any other Contributor, and only if You agree to indemnify,
|
172 |
-
defend, and hold each Contributor harmless for any liability
|
173 |
-
incurred by, or claims asserted against, such Contributor by reason
|
174 |
-
of your accepting any such warranty or additional liability.
|
175 |
-
|
176 |
-
END OF TERMS AND CONDITIONS
|
177 |
-
|
178 |
-
APPENDIX: How to apply the Apache License to your work.
|
179 |
-
|
180 |
-
To apply the Apache License to your work, attach the following
|
181 |
-
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
-
replaced with your own identifying information. (Don't include
|
183 |
-
the brackets!) The text should be enclosed in the appropriate
|
184 |
-
comment syntax for the file format. We also recommend that a
|
185 |
-
file or class name and description of purpose be included on the
|
186 |
-
same "printed page" as the copyright notice for easier
|
187 |
-
identification within third-party archives.
|
188 |
-
|
189 |
-
Copyright [yyyy] [name of copyright owner]
|
190 |
-
|
191 |
-
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
-
you may not use this file except in compliance with the License.
|
193 |
-
You may obtain a copy of the License at
|
194 |
-
|
195 |
-
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
-
|
197 |
-
Unless required by applicable law or agreed to in writing, software
|
198 |
-
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
-
See the License for the specific language governing permissions and
|
201 |
-
limitations under the License.
|
|
|
1 |
+
FREE for Non Commercial USE
|
|
|
|
|
2 |
|
3 |
+
You are free to:
|
4 |
+
- Share — copy and redistribute the material in any medium or format
|
5 |
+
- Adapt — remix, transform, and build upon the material
|
6 |
+
The licensor cannot revoke these freedoms as long as you follow the license terms.
|
7 |
|
8 |
+
Under the following terms:
|
9 |
+
- Attribution — You must give appropriate credit , provide a link to the license, and indicate if changes were made . You may do so in any reasonable manner, but not in any way that suggests the licensor endorses you or your use.
|
10 |
+
NonCommercial — You may not use the material for commercial purposes .
|
11 |
|
12 |
+
- No additional restrictions — You may not apply legal terms or technological measures that legally restrict others from doing anything the license permits.
|
13 |
+
Notices:
|
14 |
|
15 |
+
- You do not have to comply with the license for elements of the material in the public domain or where your use is permitted by an applicable exception or limitation .
|
|
|
16 |
|
17 |
+
No warranties are given. The license may not give you all of the permissions necessary for your intended use. For example, other rights such as publicity, privacy, or moral rights may limit how you use the material.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
@@ -27,378 +27,159 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
|
|
27 |
|
28 |
## 🔥 Latest News!!
|
29 |
|
|
|
30 |
* Feb 25, 2025: 👋 We've released the inference code and weights of Wan2.1.
|
31 |
* Feb 27, 2025: 👋 Wan2.1 has been integrated into [ComfyUI](https://comfyanonymous.github.io/ComfyUI_examples/wan/). Enjoy!
|
32 |
|
33 |
|
34 |
-
##
|
35 |
-
|
36 |
-
- [x] Multi-GPU Inference code of the 14B and 1.3B models
|
37 |
-
- [x] Checkpoints of the 14B and 1.3B models
|
38 |
-
- [x] Gradio demo
|
39 |
-
- [x] ComfyUI integration
|
40 |
-
- [ ] Diffusers integration
|
41 |
-
- Wan2.1 Image-to-Video
|
42 |
-
- [x] Multi-GPU Inference code of the 14B model
|
43 |
-
- [x] Checkpoints of the 14B model
|
44 |
-
- [x] Gradio demo
|
45 |
-
- [X] ComfyUI integration
|
46 |
-
- [ ] Diffusers integration
|
47 |
-
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
-
## Quickstart
|
51 |
|
52 |
-
#### Installation
|
53 |
-
Clone the repo:
|
54 |
-
```
|
55 |
-
git clone https://github.com/Wan-Video/Wan2.1.git
|
56 |
-
cd Wan2.1
|
57 |
-
```
|
58 |
-
|
59 |
-
Install dependencies:
|
60 |
-
```
|
61 |
-
# Ensure torch >= 2.4.0
|
62 |
-
pip install -r requirements.txt
|
63 |
-
```
|
64 |
-
|
65 |
-
|
66 |
-
#### Model Download
|
67 |
-
|
68 |
-
| Models | Download Link | Notes |
|
69 |
-
| --------------|-------------------------------------------------------------------------------|-------------------------------|
|
70 |
-
| T2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P
|
71 |
-
| I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P
|
72 |
-
| I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P
|
73 |
-
| T2V-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P
|
74 |
-
|
75 |
-
> 💡Note: The 1.3B model is capable of generating videos at 720P resolution. However, due to limited training at this resolution, the results are generally less stable compared to 480P. For optimal performance, we recommend using 480P resolution.
|
76 |
-
|
77 |
-
|
78 |
-
Download models using huggingface-cli:
|
79 |
-
```
|
80 |
-
pip install "huggingface_hub[cli]"
|
81 |
-
huggingface-cli download Wan-AI/Wan2.1-T2V-14B --local-dir ./Wan2.1-T2V-14B
|
82 |
-
```
|
83 |
-
|
84 |
-
Download models using modelscope-cli:
|
85 |
-
```
|
86 |
-
pip install modelscope
|
87 |
-
modelscope download Wan-AI/Wan2.1-T2V-14B --local_dir ./Wan2.1-T2V-14B
|
88 |
-
```
|
89 |
-
#### Run Text-to-Video Generation
|
90 |
-
|
91 |
-
This repository supports two Text-to-Video models (1.3B and 14B) and two resolutions (480P and 720P). The parameters and configurations for these models are as follows:
|
92 |
-
|
93 |
-
<table>
|
94 |
-
<thead>
|
95 |
-
<tr>
|
96 |
-
<th rowspan="2">Task</th>
|
97 |
-
<th colspan="2">Resolution</th>
|
98 |
-
<th rowspan="2">Model</th>
|
99 |
-
</tr>
|
100 |
-
<tr>
|
101 |
-
<th>480P</th>
|
102 |
-
<th>720P</th>
|
103 |
-
</tr>
|
104 |
-
</thead>
|
105 |
-
<tbody>
|
106 |
-
<tr>
|
107 |
-
<td>t2v-14B</td>
|
108 |
-
<td style="color: green;">✔️</td>
|
109 |
-
<td style="color: green;">✔️</td>
|
110 |
-
<td>Wan2.1-T2V-14B</td>
|
111 |
-
</tr>
|
112 |
-
<tr>
|
113 |
-
<td>t2v-1.3B</td>
|
114 |
-
<td style="color: green;">✔️</td>
|
115 |
-
<td style="color: red;">❌</td>
|
116 |
-
<td>Wan2.1-T2V-1.3B</td>
|
117 |
-
</tr>
|
118 |
-
</tbody>
|
119 |
-
</table>
|
120 |
-
|
121 |
-
|
122 |
-
##### (1) Without Prompt Extention
|
123 |
-
|
124 |
-
To facilitate implementation, we will start with a basic version of the inference process that skips the [prompt extension](#2-using-prompt-extention) step.
|
125 |
-
|
126 |
-
- Single-GPU inference
|
127 |
-
|
128 |
-
```
|
129 |
-
python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
|
130 |
-
```
|
131 |
-
|
132 |
-
If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model True` and `--t5_cpu` options to reduce GPU memory usage. For example, on an RTX 4090 GPU:
|
133 |
-
|
134 |
-
```
|
135 |
-
python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
|
136 |
-
```
|
137 |
|
138 |
-
|
139 |
|
|
|
140 |
|
141 |
-
|
142 |
|
143 |
-
|
144 |
-
|
145 |
-
torchrun --nproc_per_node=8 generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
|
146 |
-
```
|
147 |
|
148 |
|
149 |
-
|
150 |
|
151 |
-
|
|
|
152 |
|
153 |
-
|
154 |
-
- Apply for a `dashscope.api_key` in advance ([EN](https://www.alibabacloud.com/help/en/model-studio/getting-started/first-api-call-to-qwen) | [CN](https://help.aliyun.com/zh/model-studio/getting-started/first-api-call-to-qwen)).
|
155 |
-
- Configure the environment variable `DASH_API_KEY` to specify the Dashscope API key. For users of Alibaba Cloud's international site, you also need to set the environment variable `DASH_API_URL` to 'https://dashscope-intl.aliyuncs.com/api/v1'. For more detailed instructions, please refer to the [dashscope document](https://www.alibabacloud.com/help/en/model-studio/developer-reference/use-qwen-by-calling-api?spm=a2c63.p38356.0.i1).
|
156 |
-
- Use the `qwen-plus` model for text-to-video tasks and `qwen-vl-max` for image-to-video tasks.
|
157 |
-
- You can modify the model used for extension with the parameter `--prompt_extend_model`. For example:
|
158 |
-
```
|
159 |
-
DASH_API_KEY=your_key python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" --use_prompt_extend --prompt_extend_method 'dashscope' --prompt_extend_target_lang 'ch'
|
160 |
-
```
|
161 |
|
162 |
-
|
|
|
|
|
|
|
163 |
|
164 |
-
|
165 |
-
- For text-to-video tasks, you can use models like `Qwen/Qwen2.5-14B-Instruct`, `Qwen/Qwen2.5-7B-Instruct` and `Qwen/Qwen2.5-3B-Instruct`.
|
166 |
-
- For image-to-video tasks, you can use models like `Qwen/Qwen2.5-VL-7B-Instruct` and `Qwen/Qwen2.5-VL-3B-Instruct`.
|
167 |
-
- Larger models generally provide better extension results but require more GPU memory.
|
168 |
-
- You can modify the model used for extension with the parameter `--prompt_extend_model` , allowing you to specify either a local model path or a Hugging Face model. For example:
|
169 |
|
170 |
-
|
171 |
-
|
172 |
-
```
|
173 |
|
174 |
-
##### (3) Runing local gradio
|
175 |
|
176 |
-
|
177 |
-
|
178 |
-
# if one uses dashscope’s API for prompt extension
|
179 |
-
DASH_API_KEY=your_key python t2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir ./Wan2.1-T2V-14B
|
180 |
|
181 |
-
#
|
182 |
-
python
|
183 |
-
```
|
184 |
|
|
|
|
|
|
|
|
|
185 |
|
186 |
-
|
187 |
-
|
188 |
-
Similar to Text-to-Video, Image-to-Video is also divided into processes with and without the prompt extension step. The specific parameters and their corresponding settings are as follows:
|
189 |
-
<table>
|
190 |
-
<thead>
|
191 |
-
<tr>
|
192 |
-
<th rowspan="2">Task</th>
|
193 |
-
<th colspan="2">Resolution</th>
|
194 |
-
<th rowspan="2">Model</th>
|
195 |
-
</tr>
|
196 |
-
<tr>
|
197 |
-
<th>480P</th>
|
198 |
-
<th>720P</th>
|
199 |
-
</tr>
|
200 |
-
</thead>
|
201 |
-
<tbody>
|
202 |
-
<tr>
|
203 |
-
<td>i2v-14B</td>
|
204 |
-
<td style="color: green;">❌</td>
|
205 |
-
<td style="color: green;">✔️</td>
|
206 |
-
<td>Wan2.1-I2V-14B-720P</td>
|
207 |
-
</tr>
|
208 |
-
<tr>
|
209 |
-
<td>i2v-14B</td>
|
210 |
-
<td style="color: green;">✔️</td>
|
211 |
-
<td style="color: red;">❌</td>
|
212 |
-
<td>Wan2.1-T2V-14B-480P</td>
|
213 |
-
</tr>
|
214 |
-
</tbody>
|
215 |
-
</table>
|
216 |
-
|
217 |
-
|
218 |
-
##### (1) Without Prompt Extention
|
219 |
-
|
220 |
-
- Single-GPU inference
|
221 |
-
```
|
222 |
-
python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
|
223 |
-
```
|
224 |
|
225 |
-
> 💡For the Image-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
|
226 |
|
227 |
|
228 |
-
- Multi-GPU inference using FSDP + xDiT USP
|
229 |
|
230 |
```
|
231 |
-
pip install "xfuser>=0.4.1"
|
232 |
-
torchrun --nproc_per_node=8 generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
|
233 |
-
```
|
234 |
-
|
235 |
-
##### (2) Using Prompt Extention
|
236 |
-
|
237 |
|
238 |
-
|
|
|
239 |
|
240 |
-
|
|
|
|
|
241 |
```
|
242 |
-
|
243 |
```
|
244 |
|
245 |
-
|
246 |
```
|
247 |
-
|
|
|
248 |
```
|
249 |
|
250 |
-
|
251 |
|
|
|
|
|
|
|
252 |
```
|
253 |
-
cd gradio
|
254 |
-
# if one only uses 480P model in gradio
|
255 |
-
DASH_API_KEY=your_key python i2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_480p ./Wan2.1-I2V-14B-480P
|
256 |
|
257 |
-
# if one only uses 720P model in gradio
|
258 |
-
DASH_API_KEY=your_key python i2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_720p ./Wan2.1-I2V-14B-720P
|
259 |
|
260 |
-
|
261 |
-
DASH_API_KEY=your_key python i2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_480p ./Wan2.1-I2V-14B-480P --ckpt_dir_720p ./Wan2.1-I2V-14B-720P
|
262 |
-
```
|
263 |
-
|
264 |
-
|
265 |
-
#### Run Text-to-Image Generation
|
266 |
|
267 |
-
|
268 |
|
269 |
-
|
270 |
|
271 |
-
-
|
272 |
-
```
|
273 |
-
python generate.py --task t2i-14B --size 1024*1024 --ckpt_dir ./Wan2.1-T2V-14B --prompt '一个朴素端庄的美人'
|
274 |
-
```
|
275 |
|
276 |
-
|
277 |
|
278 |
-
|
279 |
-
|
|
|
280 |
```
|
281 |
|
282 |
-
|
283 |
|
284 |
-
|
285 |
-
```
|
286 |
-
python generate.py --task t2i-14B --size 1024*1024 --ckpt_dir ./Wan2.1-T2V-14B --prompt '一个朴素端庄的美人' --use_prompt_extend
|
287 |
-
```
|
288 |
|
289 |
-
- Multi-GPU inference using FSDP + xDiT USP
|
290 |
-
```
|
291 |
-
torchrun --nproc_per_node=8 generate.py --dit_fsdp --t5_fsdp --ulysses_size 8 --base_seed 0 --frame_num 1 --task t2i-14B --size 1024*1024 --ckpt_dir ./Wan2.1-T2V-14B --prompt '一个朴素端庄的美人' --use_prompt_extend
|
292 |
-
```
|
293 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
|
295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
|
297 |
-
|
|
|
298 |
|
299 |
-
|
300 |
|
301 |
-
|
302 |
-
|
303 |
-
</div>
|
304 |
-
|
305 |
-
|
306 |
-
##### (2) Image-to-Video Evaluation
|
307 |
-
|
308 |
-
We also conducted extensive manual evaluations to evaluate the performance of the Image-to-Video model, and the results are presented in the table below. The results clearly indicate that **Wan2.1** outperforms both closed-source and open-source models.
|
309 |
-
|
310 |
-
<div align="center">
|
311 |
-
<img src="assets/i2v_res.png" alt="" style="width: 80%;" />
|
312 |
-
</div>
|
313 |
-
|
314 |
-
|
315 |
-
## Computational Efficiency on Different GPUs
|
316 |
-
|
317 |
-
We test the computational efficiency of different **Wan2.1** models on different GPUs in the following table. The results are presented in the format: **Total time (s) / peak GPU memory (GB)**.
|
318 |
-
|
319 |
-
|
320 |
-
<div align="center">
|
321 |
-
<img src="assets/comp_effic.png" alt="" style="width: 80%;" />
|
322 |
-
</div>
|
323 |
-
|
324 |
-
> The parameter settings for the tests presented in this table are as follows:
|
325 |
-
> (1) For the 1.3B model on 8 GPUs, set `--ring_size 8` and `--ulysses_size 1`;
|
326 |
-
> (2) For the 14B model on 1 GPU, use `--offload_model True`;
|
327 |
-
> (3) For the 1.3B model on a single 4090 GPU, set `--offload_model True --t5_cpu`;
|
328 |
-
> (4) For all testings, no prompt extension was applied, meaning `--use_prompt_extend` was not enabled.
|
329 |
-
|
330 |
-
> 💡Note: T2V-14B is slower than I2V-14B because the former samples 50 steps while the latter uses 40 steps.
|
331 |
-
|
332 |
-
|
333 |
-
## Community Contributions
|
334 |
-
- [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) provides more support for **Wan2.1**, including video-to-video, FP8 quantization, VRAM optimization, LoRA training, and more. Please refer to [their examples](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo).
|
335 |
-
|
336 |
-
-------
|
337 |
-
|
338 |
-
## Introduction of Wan2.1
|
339 |
-
|
340 |
-
**Wan2.1** is designed on the mainstream diffusion transformer paradigm, achieving significant advancements in generative capabilities through a series of innovations. These include our novel spatio-temporal variational autoencoder (VAE), scalable training strategies, large-scale data construction, and automated evaluation metrics. Collectively, these contributions enhance the model’s performance and versatility.
|
341 |
-
|
342 |
-
|
343 |
-
##### (1) 3D Variational Autoencoders
|
344 |
-
We propose a novel 3D causal VAE architecture, termed **Wan-VAE** specifically designed for video generation. By combining multiple strategies, we improve spatio-temporal compression, reduce memory usage, and ensure temporal causality. **Wan-VAE** demonstrates significant advantages in performance efficiency compared to other open-source VAEs. Furthermore, our **Wan-VAE** can encode and decode unlimited-length 1080P videos without losing historical temporal information, making it particularly well-suited for video generation tasks.
|
345 |
-
|
346 |
-
|
347 |
-
<div align="center">
|
348 |
-
<img src="assets/video_vae_res.jpg" alt="" style="width: 80%;" />
|
349 |
-
</div>
|
350 |
-
|
351 |
-
|
352 |
-
##### (2) Video Diffusion DiT
|
353 |
-
|
354 |
-
**Wan2.1** is designed using the Flow Matching framework within the paradigm of mainstream Diffusion Transformers. Our model's architecture uses the T5 Encoder to encode multilingual text input, with cross-attention in each transformer block embedding the text into the model structure. Additionally, we employ an MLP with a Linear layer and a SiLU layer to process the input time embeddings and predict six modulation parameters individually. This MLP is shared across all transformer blocks, with each block learning a distinct set of biases. Our experimental findings reveal a significant performance improvement with this approach at the same parameter scale.
|
355 |
-
|
356 |
-
<div align="center">
|
357 |
-
<img src="assets/video_dit_arch.jpg" alt="" style="width: 80%;" />
|
358 |
-
</div>
|
359 |
-
|
360 |
-
|
361 |
-
| Model | Dimension | Input Dimension | Output Dimension | Feedforward Dimension | Frequency Dimension | Number of Heads | Number of Layers |
|
362 |
-
|--------|-----------|-----------------|------------------|-----------------------|---------------------|-----------------|------------------|
|
363 |
-
| 1.3B | 1536 | 16 | 16 | 8960 | 256 | 12 | 30 |
|
364 |
-
| 14B | 5120 | 16 | 16 | 13824 | 256 | 40 | 40 |
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
##### Data
|
369 |
-
|
370 |
-
We curated and deduplicated a candidate dataset comprising a vast amount of image and video data. During the data curation process, we designed a four-step data cleaning process, focusing on fundamental dimensions, visual quality and motion quality. Through the robust data processing pipeline, we can easily obtain high-quality, diverse, and large-scale training sets of images and videos.
|
371 |
-
|
372 |
-

|
373 |
-
|
374 |
-
|
375 |
-
##### Comparisons to SOTA
|
376 |
-
We compared **Wan2.1** with leading open-source and closed-source models to evaluate the performace. Using our carefully designed set of 1,035 internal prompts, we tested across 14 major dimensions and 26 sub-dimensions. We then compute the total score by performing a weighted calculation on the scores of each dimension, utilizing weights derived from human preferences in the matching process. The detailed results are shown in the table below. These results demonstrate our model's superior performance compared to both open-source and closed-source models.
|
377 |
-
|
378 |
-

|
379 |
-
|
380 |
-
|
381 |
-
## Citation
|
382 |
-
If you find our work helpful, please cite us.
|
383 |
-
|
384 |
-
```
|
385 |
-
@article{wan2.1,
|
386 |
-
title = {Wan: Open and Advanced Large-Scale Video Generative Models},
|
387 |
-
author = {Wan Team},
|
388 |
-
journal = {},
|
389 |
-
year = {2025}
|
390 |
-
}
|
391 |
-
```
|
392 |
|
393 |
-
|
394 |
-
|
395 |
|
|
|
|
|
396 |
|
397 |
-
|
|
|
398 |
|
399 |
-
|
|
|
400 |
|
|
|
|
|
401 |
|
402 |
|
403 |
-
## Contact Us
|
404 |
-
If you would like to leave a message to our research or product teams, feel free to join our [Discord](https://discord.gg/p5XbdQV7) or [WeChat groups](https://gw.alicdn.com/imgextra/i2/O1CN01tqjWFi1ByuyehkTSB_!!6000000000015-0-tps-611-1279.jpg)!
|
|
|
27 |
|
28 |
## 🔥 Latest News!!
|
29 |
|
30 |
+
* Mar 03, 2025: Wan2.1GP DeepBeepMeep out of this World version ! Reduced memory consumption by 2, with possiblity to generate more than 10s of video at 720p
|
31 |
* Feb 25, 2025: 👋 We've released the inference code and weights of Wan2.1.
|
32 |
* Feb 27, 2025: 👋 Wan2.1 has been integrated into [ComfyUI](https://comfyanonymous.github.io/ComfyUI_examples/wan/). Enjoy!
|
33 |
|
34 |
|
35 |
+
## Features
|
36 |
+
*GPU Poor version by **DeepBeepMeep**. This great video generator can now run smoothly on any GPU.*
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
+
This version has the following improvements over the original Hunyuan Video model:
|
39 |
+
- Reduce greatly the RAM requirements and VRAM requirements
|
40 |
+
- Much faster thanks to compilation and fast loading / unloading
|
41 |
+
- 5 profiles in order to able to run the model at a decent speed on a low end consumer config (32 GB of RAM and 12 VRAM) and to run it at a very good speed on a high end consumer config (48 GB of RAM and 24 GB of VRAM)
|
42 |
+
- Autodownloading of the needed model files
|
43 |
+
- Improved gradio interface with progression bar and more options
|
44 |
+
- Multiples prompts / multiple generations per prompt
|
45 |
+
- Support multiple pretrained Loras with 32 GB of RAM or less
|
46 |
+
- Switch easily between Hunyuan and Fast Hunyuan models and quantized / non quantized models
|
47 |
+
- Much simpler installation
|
48 |
|
|
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
+
This fork by DeepBeepMeep is an integration of the mmpg module on the gradio_server.py.
|
52 |
|
53 |
+
It is an illustration on how one can set up on an existing model some fast and properly working CPU offloading with changing only a few lines of code in the core model.
|
54 |
|
55 |
+
For more information on how to use the mmpg module, please go to: https://github.com/deepbeepmeep/mmgp
|
56 |
|
57 |
+
You will find the original Hunyuan Video repository here: https://github.com/deepbeepmeep/Wan2GP
|
58 |
+
|
|
|
|
|
59 |
|
60 |
|
61 |
+
## Installation Guide for Linux and Windows
|
62 |
|
63 |
+
We provide an `environment.yml` file for setting up a Conda environment.
|
64 |
+
Conda's installation instructions are available [here](https://docs.anaconda.com/free/miniconda/index.html).
|
65 |
|
66 |
+
This app has been tested on Python 3.10 / 2.6.0 / Cuda 12.4.\
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
+
```shell
|
69 |
+
# 1 - conda. Prepare and activate a conda environment
|
70 |
+
conda env create -f environment.yml
|
71 |
+
conda activate Wan2
|
72 |
|
73 |
+
# OR
|
|
|
|
|
|
|
|
|
74 |
|
75 |
+
# 1 - venv. Alternatively create a python 3.10 venv and then do the following
|
76 |
+
pip install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu124
|
|
|
77 |
|
|
|
78 |
|
79 |
+
# 2. Install pip dependencies
|
80 |
+
python -m pip install -r requirements.txt
|
|
|
|
|
81 |
|
82 |
+
# 3.1 optional Sage attention support (30% faster, easy to install on Linux but much harder on Windows)
|
83 |
+
python -m pip install sageattention==1.0.6
|
|
|
84 |
|
85 |
+
# or for Sage Attention 2 (40% faster, sorry only manual compilation for the moment)
|
86 |
+
git pull https://github.com/thu-ml/SageAttention
|
87 |
+
cd sageattention
|
88 |
+
pip install -e .
|
89 |
|
90 |
+
# 3.2 optional Flash attention support (easy to install on Linux but much harder on Windows)
|
91 |
+
python -m pip install flash-attn==2.7.2.post1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
|
|
93 |
|
94 |
|
|
|
95 |
|
96 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
+
Note that *Flash attention* and *Sage attention* are quite complex to install on Windows but offers a better memory management (and consequently longer videos) than the default *sdpa attention*.
|
99 |
+
Likewise *Pytorch Compilation* will work on Windows only if you manage to install Triton. It is quite a complex process (see below for links).
|
100 |
|
101 |
+
### Ready to use python wheels for Windows users
|
102 |
+
I provide here links to simplify the installation for Windows users with Python 3.10 / Pytorch 2.51 / Cuda 12.4. As I am not hosting these files I won't be able to provide support neither guarantee they do what they should do.
|
103 |
+
- Triton attention (needed for *pytorch compilation* and *Sage attention*)
|
104 |
```
|
105 |
+
pip install https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post9/triton-3.2.0-cp310-cp310-win_amd64.whl # triton for pytorch 2.6.0
|
106 |
```
|
107 |
|
108 |
+
- Sage attention
|
109 |
```
|
110 |
+
pip install https://github.com/deepbeepmeep/SageAttention/raw/refs/heads/main/releases/sageattention-2.1.0-cp310-cp310-win_amd64.whl # for pytorch 2.6.0 (experimental, if it works, otherwise you you will need to install and compile manually, see above)
|
111 |
+
|
112 |
```
|
113 |
|
114 |
+
## Run the application
|
115 |
|
116 |
+
### Run a Gradio Server on port 7860 (recommended)
|
117 |
+
```bash
|
118 |
+
python gradio_server.py
|
119 |
```
|
|
|
|
|
|
|
120 |
|
|
|
|
|
121 |
|
122 |
+
### Loras support
|
|
|
|
|
|
|
|
|
|
|
123 |
|
124 |
+
-- Ready to be used but theorical as no lora for Wan have been released as today.
|
125 |
|
126 |
+
Every lora stored in the subfoler 'loras' will be automatically loaded. You will be then able to activate / desactive any of them when running the application.
|
127 |
|
128 |
+
For each activated Lora, you may specify a *multiplier* that is one float number that corresponds to its weight (default is 1.0), alternatively you may specify a list of floats multipliers separated by a "," that gives the evolution of this Lora's multiplier over the steps. For instance let's assume there are 30 denoising steps and the multiplier is *0.9,0.8,0.7* then for the steps ranges 0-9, 10-19 and 20-29 the Lora multiplier will be respectively 0.9, 0.8 and 0.7.
|
|
|
|
|
|
|
129 |
|
130 |
+
You can edit, save or delete Loras presets (combinations of loras with their corresponding multipliers) directly from the gradio interface. Each preset, is a file with ".lset" extension stored in the loras directory and can be shared with other users
|
131 |
|
132 |
+
Then you can pre activate loras corresponding to a preset when launching the gradio server:
|
133 |
+
```bash
|
134 |
+
python gradio_server.py --lora-preset mylorapreset.lset # where 'mylorapreset.lset' is a preset stored in the 'loras' folder
|
135 |
```
|
136 |
|
137 |
+
Please note that command line parameters *--lora-weight* and *--lora-multiplier* have been deprecated since they are redundant with presets.
|
138 |
|
139 |
+
You will find prebuilt Loras on https://civitai.com/ or you will be able to build them with tools such as kohya or onetrainer.
|
|
|
|
|
|
|
140 |
|
|
|
|
|
|
|
|
|
141 |
|
142 |
+
### Command line parameters for Gradio Server
|
143 |
+
--profile no : default (4) : no of profile between 1 and 5\
|
144 |
+
--quantize-transformer bool: (default True) : enable / disable on the fly transformer quantization\
|
145 |
+
--lora-dir path : Path of directory that contains Loras in diffusers / safetensor format\
|
146 |
+
--lora-preset preset : name of preset gile (without the extension) to preload
|
147 |
+
--verbose level : default (1) : level of information between 0 and 2\
|
148 |
+
--server-port portno : default (7860) : Gradio port no\
|
149 |
+
--server-name name : default (0.0.0.0) : Gradio server name\
|
150 |
+
--open-browser : open automatically Browser when launching Gradio Server\
|
151 |
+
--compile : turn on pytorch compilation\
|
152 |
+
--attention mode: force attention mode among, sdpa, flash, sage, sage2\
|
153 |
|
154 |
+
### Profiles (for power users only)
|
155 |
+
You can choose between 5 profiles, these will try to leverage the most your hardware, but have little impact for HunyuanVideo GP:
|
156 |
+
- HighRAM_HighVRAM (1): the fastest well suited for a RTX 3090 / RTX 4090 but consumes much more VRAM, adapted for fast shorter video
|
157 |
+
- HighRAM_LowVRAM (2): a bit slower, better suited for RTX 3070/3080/4070/4080 or for RTX 3090 / RTX 4090 with large pictures batches or long videos
|
158 |
+
- LowRAM_HighVRAM (3): adapted for RTX 3090 / RTX 4090 with limited RAM but at the cost of VRAM (shorter videos)
|
159 |
+
- LowRAM_LowVRAM (4): if you have little VRAM or want to generate longer videos
|
160 |
+
- VerylowRAM_LowVRAM (5): at least 24 GB of RAM and 10 GB of VRAM : if you don't have much it won't be fast but maybe it will work
|
161 |
|
162 |
+
Profile 2 (High RAM) and 4 (Low RAM)are the most recommended profiles since they are versatile (support for long videos for a slight performance cost).\
|
163 |
+
However, a safe approach is to start from profile 5 (default profile) and then go down progressively to profile 4 and then to profile 2 as long as the app remains responsive or doesn't trigger any out of memory error.
|
164 |
|
165 |
+
### Other Models for the GPU Poor
|
166 |
|
167 |
+
- HuanyuanVideoGP: https://github.com/deepbeepmeep/HunyuanVideoGP :\
|
168 |
+
One of the best open source Text to Video generator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
+
- Hunyuan3D-2GP: https://github.com/deepbeepmeep/Hunyuan3D-2GP :\
|
171 |
+
A great image to 3D and text to 3D tool by the Tencent team. Thanks to mmgp it can run with less than 6 GB of VRAM
|
172 |
|
173 |
+
- FluxFillGP: https://github.com/deepbeepmeep/FluxFillGP :\
|
174 |
+
One of the best inpainting / outpainting tools based on Flux that can run with less than 12 GB of VRAM.
|
175 |
|
176 |
+
- Cosmos1GP: https://github.com/deepbeepmeep/Cosmos1GP :\
|
177 |
+
This application include two models: a text to world generator and a image / video to world (probably the best open source image to video generator).
|
178 |
|
179 |
+
- OminiControlGP: https://github.com/deepbeepmeep/OminiControlGP :\
|
180 |
+
A Flux derived application very powerful that can be used to transfer an object of your choice in a prompted scene. With mmgp you can run it with only 6 GB of VRAM.
|
181 |
|
182 |
+
- YuE GP: https://github.com/deepbeepmeep/YuEGP :\
|
183 |
+
A great song generator (instruments + singer's voice) based on prompted Lyrics and a genre description. Thanks to mmgp you can run it with less than 10 GB of VRAM without waiting forever.
|
184 |
|
185 |
|
|
|
|
gradio/i2v_14B_singleGPU.py
CHANGED
@@ -24,8 +24,9 @@ wan_i2v_720P = None
|
|
24 |
|
25 |
|
26 |
# Button Func
|
27 |
-
def
|
28 |
global wan_i2v_480P, wan_i2v_720P
|
|
|
29 |
|
30 |
if value == '------':
|
31 |
print("No model loaded")
|
@@ -52,8 +53,11 @@ def load_model(value):
|
|
52 |
t5_fsdp=False,
|
53 |
dit_fsdp=False,
|
54 |
use_usp=False,
|
55 |
-
|
|
|
56 |
print("done", flush=True)
|
|
|
|
|
57 |
return '720P'
|
58 |
|
59 |
if value == '480P':
|
@@ -77,11 +81,16 @@ def load_model(value):
|
|
77 |
t5_fsdp=False,
|
78 |
dit_fsdp=False,
|
79 |
use_usp=False,
|
|
|
80 |
)
|
81 |
print("done", flush=True)
|
|
|
|
|
|
|
82 |
return '480P'
|
83 |
|
84 |
|
|
|
85 |
def prompt_enc(prompt, img, tar_lang):
|
86 |
print('prompt extend...')
|
87 |
if img is None:
|
@@ -96,10 +105,12 @@ def prompt_enc(prompt, img, tar_lang):
|
|
96 |
return prompt_output.prompt
|
97 |
|
98 |
|
99 |
-
def i2v_generation(img2vid_prompt, img2vid_image,
|
100 |
guide_scale, shift_scale, seed, n_prompt):
|
101 |
# print(f"{img2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
|
102 |
-
|
|
|
|
|
103 |
if resolution == '------':
|
104 |
print(
|
105 |
'Please specify at least one resolution ckpt dir or specify the resolution'
|
@@ -118,19 +129,19 @@ def i2v_generation(img2vid_prompt, img2vid_image, resolution, sd_steps,
|
|
118 |
guide_scale=guide_scale,
|
119 |
n_prompt=n_prompt,
|
120 |
seed=seed,
|
121 |
-
offload_model=
|
122 |
else:
|
123 |
global wan_i2v_480P
|
124 |
video = wan_i2v_480P.generate(
|
125 |
img2vid_prompt,
|
126 |
img2vid_image,
|
127 |
max_area=MAX_AREA_CONFIGS['480*832'],
|
128 |
-
shift=shift_scale
|
129 |
sampling_steps=sd_steps,
|
130 |
guide_scale=guide_scale,
|
131 |
n_prompt=n_prompt,
|
132 |
seed=seed,
|
133 |
-
offload_model=
|
134 |
|
135 |
cache_video(
|
136 |
tensor=video[None],
|
@@ -169,6 +180,7 @@ def gradio_interface():
|
|
169 |
)
|
170 |
img2vid_prompt = gr.Textbox(
|
171 |
label="Prompt",
|
|
|
172 |
placeholder="Describe the video you want to generate",
|
173 |
)
|
174 |
tar_lang = gr.Radio(
|
@@ -262,6 +274,8 @@ def _parse_args():
|
|
262 |
help="The prompt extend model to use.")
|
263 |
|
264 |
args = parser.parse_args()
|
|
|
|
|
265 |
assert args.ckpt_dir_720p is not None or args.ckpt_dir_480p is not None, "Please specify at least one checkpoint directory."
|
266 |
|
267 |
return args
|
@@ -269,6 +283,12 @@ def _parse_args():
|
|
269 |
|
270 |
if __name__ == '__main__':
|
271 |
args = _parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
272 |
|
273 |
print("Step1: Init prompt_expander...", end='', flush=True)
|
274 |
if args.prompt_extend_method == "dashscope":
|
|
|
24 |
|
25 |
|
26 |
# Button Func
|
27 |
+
def load_i2v_model(value):
|
28 |
global wan_i2v_480P, wan_i2v_720P
|
29 |
+
from mmgp import offload
|
30 |
|
31 |
if value == '------':
|
32 |
print("No model loaded")
|
|
|
53 |
t5_fsdp=False,
|
54 |
dit_fsdp=False,
|
55 |
use_usp=False,
|
56 |
+
i2v720p= True
|
57 |
+
)
|
58 |
print("done", flush=True)
|
59 |
+
pipe = {"transformer": wan_i2v_720P.model, "text_encoder" : wan_i2v_720P.text_encoder.model, "text_encoder_2": wan_i2v_720P.clip.model, "vae": wan_i2v_720P.vae.model } #
|
60 |
+
offload.profile(pipe, profile_no=4, budgets = {"transformer":100, "*":3000}, verboseLevel=2, compile="transformer", quantizeTransformer = False, pinnedMemory = False)
|
61 |
return '720P'
|
62 |
|
63 |
if value == '480P':
|
|
|
81 |
t5_fsdp=False,
|
82 |
dit_fsdp=False,
|
83 |
use_usp=False,
|
84 |
+
i2v720p= False
|
85 |
)
|
86 |
print("done", flush=True)
|
87 |
+
pipe = {"transformer": wan_i2v_480P.model, "text_encoder" : wan_i2v_480P.text_encoder.model, "text_encoder_2": wan_i2v_480P.clip.model, "vae": wan_i2v_480P.vae.model } #
|
88 |
+
offload.profile(pipe, profile_no=4, budgets = {"model":100, "*":3000}, verboseLevel=2, compile="transformer" )
|
89 |
+
|
90 |
return '480P'
|
91 |
|
92 |
|
93 |
+
|
94 |
def prompt_enc(prompt, img, tar_lang):
|
95 |
print('prompt extend...')
|
96 |
if img is None:
|
|
|
105 |
return prompt_output.prompt
|
106 |
|
107 |
|
108 |
+
def i2v_generation(img2vid_prompt, img2vid_image, res, sd_steps,
|
109 |
guide_scale, shift_scale, seed, n_prompt):
|
110 |
# print(f"{img2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
|
111 |
+
global resolution
|
112 |
+
from PIL import Image
|
113 |
+
img2vid_image = Image.open("d:\mammoth2.jpg")
|
114 |
if resolution == '------':
|
115 |
print(
|
116 |
'Please specify at least one resolution ckpt dir or specify the resolution'
|
|
|
129 |
guide_scale=guide_scale,
|
130 |
n_prompt=n_prompt,
|
131 |
seed=seed,
|
132 |
+
offload_model=False)
|
133 |
else:
|
134 |
global wan_i2v_480P
|
135 |
video = wan_i2v_480P.generate(
|
136 |
img2vid_prompt,
|
137 |
img2vid_image,
|
138 |
max_area=MAX_AREA_CONFIGS['480*832'],
|
139 |
+
shift=3.0, #shift_scale
|
140 |
sampling_steps=sd_steps,
|
141 |
guide_scale=guide_scale,
|
142 |
n_prompt=n_prompt,
|
143 |
seed=seed,
|
144 |
+
offload_model=False)
|
145 |
|
146 |
cache_video(
|
147 |
tensor=video[None],
|
|
|
180 |
)
|
181 |
img2vid_prompt = gr.Textbox(
|
182 |
label="Prompt",
|
183 |
+
value="Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.",
|
184 |
placeholder="Describe the video you want to generate",
|
185 |
)
|
186 |
tar_lang = gr.Radio(
|
|
|
274 |
help="The prompt extend model to use.")
|
275 |
|
276 |
args = parser.parse_args()
|
277 |
+
args.ckpt_dir_720p = "../ckpts" # os.path.join("ckpt")
|
278 |
+
args.ckpt_dir_480p = "../ckpts" # os.path.join("ckpt")
|
279 |
assert args.ckpt_dir_720p is not None or args.ckpt_dir_480p is not None, "Please specify at least one checkpoint directory."
|
280 |
|
281 |
return args
|
|
|
283 |
|
284 |
if __name__ == '__main__':
|
285 |
args = _parse_args()
|
286 |
+
global resolution
|
287 |
+
# load_model('720P')
|
288 |
+
# resolution = '720P'
|
289 |
+
resolution = '480P'
|
290 |
+
|
291 |
+
load_model(resolution)
|
292 |
|
293 |
print("Step1: Init prompt_expander...", end='', flush=True)
|
294 |
if args.prompt_extend_method == "dashscope":
|
gradio/t2i_14B_singleGPU.py
CHANGED
@@ -190,6 +190,7 @@ if __name__ == '__main__':
|
|
190 |
|
191 |
print("Step2: Init 14B t2i model...", end='', flush=True)
|
192 |
cfg = WAN_CONFIGS['t2i-14B']
|
|
|
193 |
wan_t2i = wan.WanT2V(
|
194 |
config=cfg,
|
195 |
checkpoint_dir=args.ckpt_dir,
|
|
|
190 |
|
191 |
print("Step2: Init 14B t2i model...", end='', flush=True)
|
192 |
cfg = WAN_CONFIGS['t2i-14B']
|
193 |
+
# cfg = WAN_CONFIGS['t2v-1.3B']
|
194 |
wan_t2i = wan.WanT2V(
|
195 |
config=cfg,
|
196 |
checkpoint_dir=args.ckpt_dir,
|
gradio/t2v_14B_singleGPU.py
CHANGED
@@ -46,7 +46,7 @@ def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
|
|
46 |
guide_scale=guide_scale,
|
47 |
n_prompt=n_prompt,
|
48 |
seed=seed,
|
49 |
-
offload_model=
|
50 |
|
51 |
cache_video(
|
52 |
tensor=video[None],
|
@@ -177,28 +177,39 @@ if __name__ == '__main__':
|
|
177 |
args = _parse_args()
|
178 |
|
179 |
print("Step1: Init prompt_expander...", end='', flush=True)
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
|
|
|
|
|
|
190 |
|
191 |
print("Step2: Init 14B t2v model...", end='', flush=True)
|
192 |
cfg = WAN_CONFIGS['t2v-14B']
|
|
|
|
|
193 |
wan_t2v = wan.WanT2V(
|
194 |
config=cfg,
|
195 |
-
checkpoint_dir=
|
196 |
device_id=0,
|
197 |
rank=0,
|
198 |
t5_fsdp=False,
|
199 |
dit_fsdp=False,
|
200 |
use_usp=False,
|
201 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
print("done", flush=True)
|
203 |
|
204 |
demo = gradio_interface()
|
|
|
46 |
guide_scale=guide_scale,
|
47 |
n_prompt=n_prompt,
|
48 |
seed=seed,
|
49 |
+
offload_model=False)
|
50 |
|
51 |
cache_video(
|
52 |
tensor=video[None],
|
|
|
177 |
args = _parse_args()
|
178 |
|
179 |
print("Step1: Init prompt_expander...", end='', flush=True)
|
180 |
+
prompt_expander = None
|
181 |
+
# if args.prompt_extend_method == "dashscope":
|
182 |
+
# prompt_expander = DashScopePromptExpander(
|
183 |
+
# model_name=args.prompt_extend_model, is_vl=False)
|
184 |
+
# elif args.prompt_extend_method == "local_qwen":
|
185 |
+
# prompt_expander = QwenPromptExpander(
|
186 |
+
# model_name=args.prompt_extend_model, is_vl=False, device=0)
|
187 |
+
# else:
|
188 |
+
# raise NotImplementedError(
|
189 |
+
# f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
|
190 |
+
# print("done", flush=True)
|
191 |
+
|
192 |
+
from mmgp import offload
|
193 |
|
194 |
print("Step2: Init 14B t2v model...", end='', flush=True)
|
195 |
cfg = WAN_CONFIGS['t2v-14B']
|
196 |
+
# cfg = WAN_CONFIGS['t2v-1.3B']
|
197 |
+
|
198 |
wan_t2v = wan.WanT2V(
|
199 |
config=cfg,
|
200 |
+
checkpoint_dir="../ckpts",
|
201 |
device_id=0,
|
202 |
rank=0,
|
203 |
t5_fsdp=False,
|
204 |
dit_fsdp=False,
|
205 |
use_usp=False,
|
206 |
)
|
207 |
+
|
208 |
+
pipe = {"transformer": wan_t2v.model, "text_encoder" : wan_t2v.text_encoder.model, "vae": wan_t2v.vae.model } #
|
209 |
+
# offload.profile(pipe, profile_no=4, budgets = {"transformer":100, "*":3000}, verboseLevel=2, quantizeTransformer = False, compile = "transformer") #
|
210 |
+
offload.profile(pipe, profile_no=4, budgets = {"transformer":100, "*":3000}, verboseLevel=2, quantizeTransformer = False) #
|
211 |
+
# offload.profile(pipe, profile_no=4, budgets = {"transformer":3000, "*":3000}, verboseLevel=2, quantizeTransformer = False)
|
212 |
+
|
213 |
print("done", flush=True)
|
214 |
|
215 |
demo = gradio_interface()
|
gradio_server.py
ADDED
@@ -0,0 +1,1275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import argparse
|
4 |
+
from mmgp import offload, safetensors2, profile_type
|
5 |
+
try:
|
6 |
+
import triton
|
7 |
+
except ImportError:
|
8 |
+
pass
|
9 |
+
from pathlib import Path
|
10 |
+
from datetime import datetime
|
11 |
+
import gradio as gr
|
12 |
+
import random
|
13 |
+
import json
|
14 |
+
import wan
|
15 |
+
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES
|
16 |
+
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
17 |
+
from wan.utils.utils import cache_video
|
18 |
+
from wan.modules.attention import get_attention_modes
|
19 |
+
import torch
|
20 |
+
import gc
|
21 |
+
|
22 |
+
def _parse_args():
|
23 |
+
parser = argparse.ArgumentParser(
|
24 |
+
description="Generate a video from a text prompt or image using Gradio")
|
25 |
+
parser.add_argument(
|
26 |
+
"--ckpt_dir_720p",
|
27 |
+
type=str,
|
28 |
+
default=None,
|
29 |
+
help="The path to the checkpoint directory.")
|
30 |
+
parser.add_argument(
|
31 |
+
"--ckpt_dir_480p",
|
32 |
+
type=str,
|
33 |
+
default=None,
|
34 |
+
help="The path to the checkpoint directory.")
|
35 |
+
parser.add_argument(
|
36 |
+
"--prompt_extend_method",
|
37 |
+
type=str,
|
38 |
+
default="local_qwen",
|
39 |
+
choices=["dashscope", "local_qwen"],
|
40 |
+
help="The prompt extend method to use.")
|
41 |
+
parser.add_argument(
|
42 |
+
"--prompt_extend_model",
|
43 |
+
type=str,
|
44 |
+
default=None,
|
45 |
+
help="The prompt extend model to use.")
|
46 |
+
|
47 |
+
parser.add_argument(
|
48 |
+
"--quantize-transformer",
|
49 |
+
action="store_true",
|
50 |
+
help="On the fly 'transformer' quantization"
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
parser.add_argument(
|
55 |
+
"--lora-dir-i2v",
|
56 |
+
type=str,
|
57 |
+
default="loras_i2v",
|
58 |
+
help="Path to a directory that contains Loras for i2v"
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
parser.add_argument(
|
63 |
+
"--lora-dir",
|
64 |
+
type=str,
|
65 |
+
default="loras",
|
66 |
+
help="Path to a directory that contains Loras"
|
67 |
+
)
|
68 |
+
|
69 |
+
|
70 |
+
parser.add_argument(
|
71 |
+
"--lora-preset",
|
72 |
+
type=str,
|
73 |
+
default="",
|
74 |
+
help="Lora preset to preload"
|
75 |
+
)
|
76 |
+
|
77 |
+
parser.add_argument(
|
78 |
+
"--lora-preset-i2v",
|
79 |
+
type=str,
|
80 |
+
default="",
|
81 |
+
help="Lora preset to preload for i2v"
|
82 |
+
)
|
83 |
+
|
84 |
+
parser.add_argument(
|
85 |
+
"--profile",
|
86 |
+
type=str,
|
87 |
+
default=-1,
|
88 |
+
help="Profile No"
|
89 |
+
)
|
90 |
+
|
91 |
+
parser.add_argument(
|
92 |
+
"--verbose",
|
93 |
+
type=str,
|
94 |
+
default=1,
|
95 |
+
help="Verbose level"
|
96 |
+
)
|
97 |
+
|
98 |
+
parser.add_argument(
|
99 |
+
"--server-port",
|
100 |
+
type=str,
|
101 |
+
default=0,
|
102 |
+
help="Server port"
|
103 |
+
)
|
104 |
+
|
105 |
+
parser.add_argument(
|
106 |
+
"--server-name",
|
107 |
+
type=str,
|
108 |
+
default="",
|
109 |
+
help="Server name"
|
110 |
+
)
|
111 |
+
|
112 |
+
parser.add_argument(
|
113 |
+
"--open-browser",
|
114 |
+
action="store_true",
|
115 |
+
help="open browser"
|
116 |
+
)
|
117 |
+
|
118 |
+
parser.add_argument(
|
119 |
+
"--t2v",
|
120 |
+
action="store_true",
|
121 |
+
help="text to video mode"
|
122 |
+
)
|
123 |
+
|
124 |
+
parser.add_argument(
|
125 |
+
"--i2v",
|
126 |
+
action="store_true",
|
127 |
+
help="image to video mode"
|
128 |
+
)
|
129 |
+
|
130 |
+
parser.add_argument(
|
131 |
+
"--compile",
|
132 |
+
action="store_true",
|
133 |
+
help="Enable pytorch compilation"
|
134 |
+
)
|
135 |
+
|
136 |
+
# parser.add_argument(
|
137 |
+
# "--fast",
|
138 |
+
# action="store_true",
|
139 |
+
# help="use Fast model"
|
140 |
+
# )
|
141 |
+
|
142 |
+
# parser.add_argument(
|
143 |
+
# "--fastest",
|
144 |
+
# action="store_true",
|
145 |
+
# help="activate the best config"
|
146 |
+
# )
|
147 |
+
|
148 |
+
parser.add_argument(
|
149 |
+
"--attention",
|
150 |
+
type=str,
|
151 |
+
default="",
|
152 |
+
help="attention mode"
|
153 |
+
)
|
154 |
+
|
155 |
+
parser.add_argument(
|
156 |
+
"--vae-config",
|
157 |
+
type=str,
|
158 |
+
default="",
|
159 |
+
help="vae config mode"
|
160 |
+
)
|
161 |
+
|
162 |
+
|
163 |
+
args = parser.parse_args()
|
164 |
+
args.ckpt_dir_720p = "../ckpts" # os.path.join("ckpt")
|
165 |
+
args.ckpt_dir_480p = "../ckpts" # os.path.join("ckpt")
|
166 |
+
assert args.ckpt_dir_720p is not None or args.ckpt_dir_480p is not None, "Please specify at least one checkpoint directory."
|
167 |
+
|
168 |
+
return args
|
169 |
+
|
170 |
+
attention_modes_supported = get_attention_modes()
|
171 |
+
|
172 |
+
args = _parse_args()
|
173 |
+
args.flow_reverse = True
|
174 |
+
|
175 |
+
|
176 |
+
lock_ui_attention = False
|
177 |
+
lock_ui_transformer = False
|
178 |
+
lock_ui_compile = False
|
179 |
+
|
180 |
+
|
181 |
+
force_profile_no = int(args.profile)
|
182 |
+
verbose_level = int(args.verbose)
|
183 |
+
quantizeTransformer = args.quantize_transformer
|
184 |
+
|
185 |
+
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors"]
|
186 |
+
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors"]
|
187 |
+
text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
|
188 |
+
|
189 |
+
server_config_filename = "gradio_config.json"
|
190 |
+
|
191 |
+
if not Path(server_config_filename).is_file():
|
192 |
+
server_config = {"attention_mode" : "auto",
|
193 |
+
"transformer_filename": transformer_choices_t2v[0],
|
194 |
+
"transformer_filename_i2v": transformer_choices_i2v[1], ########
|
195 |
+
"text_encoder_filename" : text_encoder_choices[1],
|
196 |
+
"compile" : "",
|
197 |
+
"default_ui": "t2v",
|
198 |
+
"vae_config": 0,
|
199 |
+
"profile" : profile_type.LowRAM_LowVRAM }
|
200 |
+
|
201 |
+
with open(server_config_filename, "w", encoding="utf-8") as writer:
|
202 |
+
writer.write(json.dumps(server_config))
|
203 |
+
else:
|
204 |
+
with open(server_config_filename, "r", encoding="utf-8") as reader:
|
205 |
+
text = reader.read()
|
206 |
+
server_config = json.loads(text)
|
207 |
+
|
208 |
+
|
209 |
+
transformer_filename_t2v = server_config["transformer_filename"]
|
210 |
+
transformer_filename_i2v = server_config.get("transformer_filename_i2v", transformer_choices_i2v[1]) ########
|
211 |
+
|
212 |
+
text_encoder_filename = server_config["text_encoder_filename"]
|
213 |
+
attention_mode = server_config["attention_mode"]
|
214 |
+
if len(args.attention)> 0:
|
215 |
+
if args.attention in ["auto", "sdpa", "sage", "sage2", "flash", "xformers"]:
|
216 |
+
attention_mode = args.attention
|
217 |
+
lock_ui_attention = True
|
218 |
+
else:
|
219 |
+
raise Exception(f"Unknown attention mode '{args.attention}'")
|
220 |
+
|
221 |
+
profile = force_profile_no if force_profile_no >=0 else server_config["profile"]
|
222 |
+
compile = server_config.get("compile", "")
|
223 |
+
vae_config = server_config.get("vae_config", 0)
|
224 |
+
if len(args.vae_config) > 0:
|
225 |
+
vae_config = int(args.vae_config)
|
226 |
+
|
227 |
+
default_ui = server_config.get("default_ui", "t2v")
|
228 |
+
use_image2video = default_ui != "t2v"
|
229 |
+
if args.t2v:
|
230 |
+
use_image2video = False
|
231 |
+
if args.i2v:
|
232 |
+
use_image2video = True
|
233 |
+
|
234 |
+
if use_image2video:
|
235 |
+
lora_dir =args.lora_dir_i2v
|
236 |
+
lora_preselected_preset = args.lora_preset_i2v
|
237 |
+
else:
|
238 |
+
lora_dir =args.lora_dir
|
239 |
+
lora_preselected_preset = args.lora_preset
|
240 |
+
|
241 |
+
default_tea_cache = 0
|
242 |
+
# if args.fast : #or args.fastest
|
243 |
+
# transformer_filename_t2v = transformer_choices_t2v[2]
|
244 |
+
# attention_mode="sage2" if "sage2" in attention_modes_supported else "sage"
|
245 |
+
# default_tea_cache = 0.15
|
246 |
+
# lock_ui_attention = True
|
247 |
+
# lock_ui_transformer = True
|
248 |
+
|
249 |
+
if args.compile: #args.fastest or
|
250 |
+
compile="transformer"
|
251 |
+
lock_ui_compile = True
|
252 |
+
|
253 |
+
|
254 |
+
#attention_mode="sage"
|
255 |
+
#attention_mode="sage2"
|
256 |
+
#attention_mode="flash"
|
257 |
+
#attention_mode="sdpa"
|
258 |
+
#attention_mode="xformers"
|
259 |
+
# compile = "transformer"
|
260 |
+
|
261 |
+
def download_models(transformer_filename, text_encoder_filename):
|
262 |
+
def computeList(filename):
|
263 |
+
pos = filename.rfind("/")
|
264 |
+
filename = filename[pos+1:]
|
265 |
+
return [filename]
|
266 |
+
|
267 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
268 |
+
repoId = "DeepBeepMeep/Wan2.1"
|
269 |
+
sourceFolderList = ["xlm-roberta-large", "", ]
|
270 |
+
fileList = [ [], ["Wan2.1_VAE.pth", "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
|
271 |
+
targetRoot = "ckpts/"
|
272 |
+
for sourceFolder, files in zip(sourceFolderList,fileList ):
|
273 |
+
if len(files)==0:
|
274 |
+
if not Path(targetRoot + sourceFolder).exists():
|
275 |
+
snapshot_download(repo_id=repoId, allow_patterns=sourceFolder +"/*", local_dir= targetRoot)
|
276 |
+
else:
|
277 |
+
for onefile in files:
|
278 |
+
if len(sourceFolder) > 0:
|
279 |
+
if not os.path.isfile(targetRoot + sourceFolder + "/" + onefile ):
|
280 |
+
hf_hub_download(repo_id=repoId, filename=onefile, local_dir = targetRoot, subfolder=sourceFolder)
|
281 |
+
else:
|
282 |
+
if not os.path.isfile(targetRoot + onefile ):
|
283 |
+
hf_hub_download(repo_id=repoId, filename=onefile, local_dir = targetRoot)
|
284 |
+
|
285 |
+
|
286 |
+
offload.default_verboseLevel = verbose_level
|
287 |
+
|
288 |
+
download_models(transformer_filename_i2v if use_image2video else transformer_filename_t2v, text_encoder_filename)
|
289 |
+
|
290 |
+
def sanitize_file_name(file_name):
|
291 |
+
return file_name.replace("/","").replace("\\","").replace(":","").replace("|","").replace("?","").replace("<","").replace(">","").replace("\"","")
|
292 |
+
|
293 |
+
def extract_preset(lset_name, loras):
|
294 |
+
lset_name = sanitize_file_name(lset_name)
|
295 |
+
if not lset_name.endswith(".lset"):
|
296 |
+
lset_name_filename = os.path.join(lora_dir, lset_name + ".lset" )
|
297 |
+
else:
|
298 |
+
lset_name_filename = os.path.join(lora_dir, lset_name )
|
299 |
+
|
300 |
+
if not os.path.isfile(lset_name_filename):
|
301 |
+
raise gr.Error(f"Preset '{lset_name}' not found ")
|
302 |
+
|
303 |
+
with open(lset_name_filename, "r", encoding="utf-8") as reader:
|
304 |
+
text = reader.read()
|
305 |
+
lset = json.loads(text)
|
306 |
+
|
307 |
+
loras_choices_files = lset["loras"]
|
308 |
+
loras_choices = []
|
309 |
+
missing_loras = []
|
310 |
+
for lora_file in loras_choices_files:
|
311 |
+
loras_choice_no = loras.index(os.path.join(lora_dir, lora_file))
|
312 |
+
if loras_choice_no < 0:
|
313 |
+
missing_loras.append(lora_file)
|
314 |
+
else:
|
315 |
+
loras_choices.append(str(loras_choice_no))
|
316 |
+
|
317 |
+
if len(missing_loras) > 0:
|
318 |
+
raise gr.Error(f"Unable to apply Lora preset '{lset_name} because the following Loras files are missing: {missing_loras}")
|
319 |
+
|
320 |
+
loras_mult_choices = lset["loras_mult"]
|
321 |
+
return loras_choices, loras_mult_choices
|
322 |
+
|
323 |
+
def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_map = None):
|
324 |
+
loras =[]
|
325 |
+
loras_names = []
|
326 |
+
default_loras_choices = []
|
327 |
+
default_loras_multis_str = ""
|
328 |
+
loras_presets = []
|
329 |
+
|
330 |
+
from pathlib import Path
|
331 |
+
|
332 |
+
if lora_dir != None :
|
333 |
+
if not os.path.isdir(lora_dir):
|
334 |
+
raise Exception("--lora-dir should be a path to a directory that contains Loras")
|
335 |
+
|
336 |
+
default_lora_preset = ""
|
337 |
+
|
338 |
+
if lora_dir != None:
|
339 |
+
import glob
|
340 |
+
dir_loras = glob.glob( os.path.join(lora_dir , "*.sft") ) + glob.glob( os.path.join(lora_dir , "*.safetensors") )
|
341 |
+
dir_loras.sort()
|
342 |
+
loras += [element for element in dir_loras if element not in loras ]
|
343 |
+
|
344 |
+
dir_presets = glob.glob( os.path.join(lora_dir , "*.lset") )
|
345 |
+
dir_presets.sort()
|
346 |
+
loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets]
|
347 |
+
|
348 |
+
if len(loras) > 0:
|
349 |
+
loras_names = [ Path(lora).stem for lora in loras ]
|
350 |
+
offload.load_loras_into_model(pipe.transformer, loras, activate_all_loras=False, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
|
351 |
+
|
352 |
+
if len(lora_preselected_preset) > 0:
|
353 |
+
if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")):
|
354 |
+
raise Exception(f"Unknown preset '{lora_preselected_preset}'")
|
355 |
+
default_lora_preset = lora_preselected_preset
|
356 |
+
default_loras_choices, default_loras_multis_str= extract_preset(default_lora_preset, loras)
|
357 |
+
|
358 |
+
return loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets
|
359 |
+
|
360 |
+
|
361 |
+
def load_t2v_model(model_filename, value):
|
362 |
+
|
363 |
+
cfg = WAN_CONFIGS['t2v-14B']
|
364 |
+
# cfg = WAN_CONFIGS['t2v-1.3B']
|
365 |
+
print("load t2v model...")
|
366 |
+
|
367 |
+
wan_model = wan.WanT2V(
|
368 |
+
config=cfg,
|
369 |
+
checkpoint_dir="ckpts",
|
370 |
+
device_id=0,
|
371 |
+
rank=0,
|
372 |
+
t5_fsdp=False,
|
373 |
+
dit_fsdp=False,
|
374 |
+
use_usp=False,
|
375 |
+
model_filename=model_filename,
|
376 |
+
text_encoder_filename= text_encoder_filename
|
377 |
+
)
|
378 |
+
|
379 |
+
pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model }
|
380 |
+
|
381 |
+
return wan_model, pipe
|
382 |
+
|
383 |
+
def load_i2v_model(model_filename, value):
|
384 |
+
|
385 |
+
|
386 |
+
if value == '720P':
|
387 |
+
print("load 14B-720P i2v model...")
|
388 |
+
cfg = WAN_CONFIGS['i2v-14B']
|
389 |
+
wan_model = wan.WanI2V(
|
390 |
+
config=cfg,
|
391 |
+
checkpoint_dir="ckpts",
|
392 |
+
device_id=0,
|
393 |
+
rank=0,
|
394 |
+
t5_fsdp=False,
|
395 |
+
dit_fsdp=False,
|
396 |
+
use_usp=False,
|
397 |
+
i2v720p= True,
|
398 |
+
model_filename=model_filename,
|
399 |
+
text_encoder_filename=text_encoder_filename
|
400 |
+
)
|
401 |
+
pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "text_encoder_2": wan_model.clip.model, "vae": wan_model.vae.model } #
|
402 |
+
|
403 |
+
if value == '480P':
|
404 |
+
print("load 14B-480P i2v model...")
|
405 |
+
cfg = WAN_CONFIGS['i2v-14B']
|
406 |
+
wan_model = wan.WanI2V(
|
407 |
+
config=cfg,
|
408 |
+
checkpoint_dir="ckpts",
|
409 |
+
device_id=0,
|
410 |
+
rank=0,
|
411 |
+
t5_fsdp=False,
|
412 |
+
dit_fsdp=False,
|
413 |
+
use_usp=False,
|
414 |
+
i2v720p= False,
|
415 |
+
model_filename=model_filename,
|
416 |
+
text_encoder_filename=text_encoder_filename
|
417 |
+
|
418 |
+
)
|
419 |
+
pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "text_encoder_2": wan_model.clip.model, "vae": wan_model.vae.model } #
|
420 |
+
|
421 |
+
return wan_model, pipe
|
422 |
+
|
423 |
+
def load_models(i2v, lora_dir, lora_preselected_preset ):
|
424 |
+
download_models(transformer_filename_i2v if i2v else transformer_filename_t2v, text_encoder_filename)
|
425 |
+
|
426 |
+
if i2v:
|
427 |
+
res720P= "720p" in transformer_filename_i2v
|
428 |
+
wan_model, pipe = load_i2v_model(transformer_filename_i2v,"720P" if res720P else "480P")
|
429 |
+
else:
|
430 |
+
wan_model, pipe = load_t2v_model(transformer_filename_t2v,"")
|
431 |
+
|
432 |
+
kwargs = { "extraModelsToQuantize": None}
|
433 |
+
if profile == 2 or profile == 4:
|
434 |
+
kwargs["budgets"] = { "transformer" : 100, "*" : 3000 }
|
435 |
+
|
436 |
+
loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None)
|
437 |
+
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, **kwargs)
|
438 |
+
|
439 |
+
|
440 |
+
return wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets
|
441 |
+
|
442 |
+
wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
|
443 |
+
gen_in_progress = False
|
444 |
+
|
445 |
+
def get_auto_attention():
|
446 |
+
for attn in ["sage2","sage","sdpa"]:
|
447 |
+
if attn in attention_modes_supported:
|
448 |
+
return attn
|
449 |
+
return "sdpa"
|
450 |
+
|
451 |
+
def get_default_flow(model_filename):
|
452 |
+
return 3.0 if "480p" in model_filename else 5.0
|
453 |
+
|
454 |
+
def generate_header(model_filename, compile, attention_mode):
|
455 |
+
header = "<H2 ALIGN=CENTER><SPAN> ----------------- "
|
456 |
+
|
457 |
+
if "image" in model_filename:
|
458 |
+
model_name = "Wan2.1 image2video"
|
459 |
+
model_name += "720p" if "720p" in model_filename else "480"
|
460 |
+
else:
|
461 |
+
model_name = "Wan2.1 text2video"
|
462 |
+
model_name += "14B" if "14B" in model_filename else "1.3B"
|
463 |
+
|
464 |
+
header += model_name
|
465 |
+
header += " (attention mode: " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
|
466 |
+
if attention_mode not in attention_modes_supported:
|
467 |
+
header += " -NOT INSTALLED-"
|
468 |
+
|
469 |
+
if compile:
|
470 |
+
header += ", pytorch compilation ON"
|
471 |
+
header += ") -----------------</SPAN></H2>"
|
472 |
+
|
473 |
+
return header
|
474 |
+
|
475 |
+
def apply_changes( state,
|
476 |
+
transformer_t2v_choice,
|
477 |
+
transformer_i2v_choice,
|
478 |
+
text_encoder_choice,
|
479 |
+
attention_choice,
|
480 |
+
compile_choice,
|
481 |
+
profile_choice,
|
482 |
+
vae_config_choice,
|
483 |
+
default_ui_choice ="t2v",
|
484 |
+
):
|
485 |
+
|
486 |
+
if gen_in_progress:
|
487 |
+
yield "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
|
488 |
+
return
|
489 |
+
global offloadobj, wan_model, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets
|
490 |
+
server_config = {"attention_mode" : attention_choice,
|
491 |
+
"transformer_filename": transformer_choices_t2v[transformer_t2v_choice],
|
492 |
+
"transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice], ##########
|
493 |
+
"text_encoder_filename" : text_encoder_choices[text_encoder_choice],
|
494 |
+
"compile" : compile_choice,
|
495 |
+
"profile" : profile_choice,
|
496 |
+
"vae_config" : vae_config_choice,
|
497 |
+
"default_ui" : default_ui_choice,
|
498 |
+
}
|
499 |
+
|
500 |
+
if Path(server_config_filename).is_file():
|
501 |
+
with open(server_config_filename, "r", encoding="utf-8") as reader:
|
502 |
+
text = reader.read()
|
503 |
+
old_server_config = json.loads(text)
|
504 |
+
if lock_ui_transformer:
|
505 |
+
server_config["transformer_filename"] = old_server_config["transformer_filename"]
|
506 |
+
server_config["transformer_filename_i2v"] = old_server_config["transformer_filename_i2v"]
|
507 |
+
if lock_ui_attention:
|
508 |
+
server_config["attention_mode"] = old_server_config["attention_mode"]
|
509 |
+
if lock_ui_compile:
|
510 |
+
server_config["compile"] = old_server_config["compile"]
|
511 |
+
|
512 |
+
with open(server_config_filename, "w", encoding="utf-8") as writer:
|
513 |
+
writer.write(json.dumps(server_config))
|
514 |
+
|
515 |
+
changes = []
|
516 |
+
for k, v in server_config.items():
|
517 |
+
v_old = old_server_config.get(k, None)
|
518 |
+
if v != v_old:
|
519 |
+
changes.append(k)
|
520 |
+
|
521 |
+
state["config_changes"] = changes
|
522 |
+
state["config_new"] = server_config
|
523 |
+
state["config_old"] = old_server_config
|
524 |
+
|
525 |
+
global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config
|
526 |
+
attention_mode = server_config["attention_mode"]
|
527 |
+
profile = server_config["profile"]
|
528 |
+
compile = server_config["compile"]
|
529 |
+
transformer_filename_t2v = server_config["transformer_filename"]
|
530 |
+
transformer_filename_i2v = server_config["transformer_filename_i2v"]
|
531 |
+
text_encoder_filename = server_config["text_encoder_filename"]
|
532 |
+
vae_config = server_config["vae_config"]
|
533 |
+
|
534 |
+
if all(change in ["attention_mode", "vae_config", "default_ui"] for change in changes ):
|
535 |
+
if "attention_mode" in changes:
|
536 |
+
pass
|
537 |
+
|
538 |
+
else:
|
539 |
+
wan_model = None
|
540 |
+
offloadobj.release()
|
541 |
+
offloadobj = None
|
542 |
+
yield "<DIV ALIGN=CENTER>Please wait while the new configuration is being applied</DIV>"
|
543 |
+
|
544 |
+
wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
|
545 |
+
|
546 |
+
|
547 |
+
yield "<DIV ALIGN=CENTER>The new configuration has been succesfully applied</DIV>"
|
548 |
+
|
549 |
+
# return "<DIV ALIGN=CENTER>New Config file created. Please restart the Gradio Server</DIV>"
|
550 |
+
|
551 |
+
def update_defaults(state, num_inference_steps,flow_shift):
|
552 |
+
if "config_changes" not in state:
|
553 |
+
return get_default_flow("")
|
554 |
+
changes = state["config_changes"]
|
555 |
+
server_config = state["config_new"]
|
556 |
+
old_server_config = state["config_old"]
|
557 |
+
|
558 |
+
if use_image2video:
|
559 |
+
old_is_14B = "14B" in server_config["transformer_filename"]
|
560 |
+
new_is_14B = "14B" in old_server_config["transformer_filename"]
|
561 |
+
|
562 |
+
trans_file = server_config["transformer_filename"]
|
563 |
+
# if old_is_14B != new_is_14B:
|
564 |
+
# num_inference_steps, flow_shift = get_default_flow(trans_file)
|
565 |
+
else:
|
566 |
+
old_is_720P = "720P" in server_config["transformer_filename_i2v"]
|
567 |
+
new_is_720P = "720P" in old_server_config["transformer_filename_i2v"]
|
568 |
+
trans_file = server_config["transformer_filename_i2v"]
|
569 |
+
if old_is_720P != new_is_720P:
|
570 |
+
num_inference_steps, flow_shift = get_default_flow(trans_file)
|
571 |
+
|
572 |
+
header = generate_header(trans_file, server_config["compile"], server_config["attention_mode"] )
|
573 |
+
return num_inference_steps, flow_shift, header
|
574 |
+
|
575 |
+
|
576 |
+
from moviepy.editor import ImageSequenceClip
|
577 |
+
import numpy as np
|
578 |
+
|
579 |
+
def save_video(final_frames, output_path, fps=24):
|
580 |
+
assert final_frames.ndim == 4 and final_frames.shape[3] == 3, f"invalid shape: {final_frames} (need t h w c)"
|
581 |
+
if final_frames.dtype != np.uint8:
|
582 |
+
final_frames = (final_frames * 255).astype(np.uint8)
|
583 |
+
ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)
|
584 |
+
|
585 |
+
def build_callback(state, pipe, progress, status, num_inference_steps):
|
586 |
+
def callback(step_idx, latents):
|
587 |
+
step_idx += 1
|
588 |
+
if state.get("abort", False):
|
589 |
+
# pipe._interrupt = True
|
590 |
+
status_msg = status + " - Aborting"
|
591 |
+
elif step_idx == num_inference_steps:
|
592 |
+
status_msg = status + " - VAE Decoding"
|
593 |
+
else:
|
594 |
+
status_msg = status + " - Denoising"
|
595 |
+
|
596 |
+
progress( (step_idx , num_inference_steps) , status_msg , num_inference_steps)
|
597 |
+
|
598 |
+
return callback
|
599 |
+
|
600 |
+
def abort_generation(state):
|
601 |
+
if "in_progress" in state:
|
602 |
+
state["abort"] = True
|
603 |
+
wan_model._interrupt= True
|
604 |
+
return gr.Button(interactive= False)
|
605 |
+
else:
|
606 |
+
return gr.Button(interactive= True)
|
607 |
+
|
608 |
+
def refresh_gallery(state):
|
609 |
+
file_list = state.get("file_list", None)
|
610 |
+
return file_list
|
611 |
+
|
612 |
+
def finalize_gallery(state):
|
613 |
+
choice = 0
|
614 |
+
if "in_progress" in state:
|
615 |
+
del state["in_progress"]
|
616 |
+
choice = state.get("selected",0)
|
617 |
+
|
618 |
+
time.sleep(0.2)
|
619 |
+
global gen_in_progress
|
620 |
+
gen_in_progress = False
|
621 |
+
return gr.Gallery(selected_index=choice), gr.Button(interactive= True)
|
622 |
+
|
623 |
+
def select_video(state , event_data: gr.EventData):
|
624 |
+
data= event_data._data
|
625 |
+
if data!=None:
|
626 |
+
state["selected"] = data.get("index",0)
|
627 |
+
return
|
628 |
+
|
629 |
+
def expand_slist(slist, num_inference_steps ):
|
630 |
+
new_slist= []
|
631 |
+
inc = len(slist) / num_inference_steps
|
632 |
+
pos = 0
|
633 |
+
for i in range(num_inference_steps):
|
634 |
+
new_slist.append(slist[ int(pos)])
|
635 |
+
pos += inc
|
636 |
+
return new_slist
|
637 |
+
|
638 |
+
|
639 |
+
def generate_video(
|
640 |
+
prompt,
|
641 |
+
negative_prompt,
|
642 |
+
resolution,
|
643 |
+
video_length,
|
644 |
+
seed,
|
645 |
+
num_inference_steps,
|
646 |
+
guidance_scale,
|
647 |
+
flow_shift,
|
648 |
+
embedded_guidance_scale,
|
649 |
+
repeat_generation,
|
650 |
+
tea_cache,
|
651 |
+
loras_choices,
|
652 |
+
loras_mult_choices,
|
653 |
+
image_to_continue,
|
654 |
+
video_to_continue,
|
655 |
+
max_frames,
|
656 |
+
RIFLEx_setting,
|
657 |
+
state,
|
658 |
+
progress=gr.Progress() #track_tqdm= True
|
659 |
+
|
660 |
+
):
|
661 |
+
|
662 |
+
from PIL import Image
|
663 |
+
import numpy as np
|
664 |
+
import tempfile
|
665 |
+
|
666 |
+
|
667 |
+
if wan_model == None:
|
668 |
+
raise gr.Error("Unable to generate a Video while a new configuration is being applied.")
|
669 |
+
if attention_mode == "auto":
|
670 |
+
attn = get_auto_attention()
|
671 |
+
elif attention_mode in attention_modes_supported:
|
672 |
+
attn = attention_mode
|
673 |
+
else:
|
674 |
+
raise gr.Error(f"You have selected attention mode '{attention_mode}'. However it is not installed on your system. You should either install it or switch to the default 'sdpa' attention.")
|
675 |
+
|
676 |
+
width, height = resolution.split("x")
|
677 |
+
width, height = int(width), int(height)
|
678 |
+
|
679 |
+
|
680 |
+
if use_image2video:
|
681 |
+
if "480p" in transformer_filename_i2v and width * height > 848*480:
|
682 |
+
raise gr.Error("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
|
683 |
+
|
684 |
+
resolution = str(width) + "*" + str(height)
|
685 |
+
if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
|
686 |
+
raise gr.Error(f"Resolution {resolution} not supported by image 2 video")
|
687 |
+
|
688 |
+
|
689 |
+
else:
|
690 |
+
if "1.3B" in transformer_filename_t2v and width * height > 848*480:
|
691 |
+
raise gr.Error("You must use the 14B text to video model to generate videos with a resolution equivalent to 720P")
|
692 |
+
|
693 |
+
offload.shared_state["_vae"] = vae_config
|
694 |
+
offload.shared_state["_vae_threshold"] = 0.9* torch.cuda.get_device_properties(0).total_memory
|
695 |
+
|
696 |
+
offload.shared_state["_attention"] = attn
|
697 |
+
|
698 |
+
global gen_in_progress
|
699 |
+
gen_in_progress = True
|
700 |
+
temp_filename = None
|
701 |
+
if use_image2video:
|
702 |
+
if image_to_continue is not None:
|
703 |
+
pass
|
704 |
+
|
705 |
+
elif video_to_continue != None and len(video_to_continue) >0 :
|
706 |
+
input_image_or_video_path = video_to_continue
|
707 |
+
# pipeline.num_input_frames = max_frames
|
708 |
+
# pipeline.max_frames = max_frames
|
709 |
+
else:
|
710 |
+
return
|
711 |
+
else:
|
712 |
+
input_image_or_video_path = None
|
713 |
+
|
714 |
+
|
715 |
+
if len(loras) > 0:
|
716 |
+
def is_float(element: any) -> bool:
|
717 |
+
if element is None:
|
718 |
+
return False
|
719 |
+
try:
|
720 |
+
float(element)
|
721 |
+
return True
|
722 |
+
except ValueError:
|
723 |
+
return False
|
724 |
+
list_mult_choices_nums = []
|
725 |
+
if len(loras_mult_choices) > 0:
|
726 |
+
list_mult_choices_str = loras_mult_choices.split(" ")
|
727 |
+
for i, mult in enumerate(list_mult_choices_str):
|
728 |
+
mult = mult.strip()
|
729 |
+
if "," in mult:
|
730 |
+
multlist = mult.split(",")
|
731 |
+
slist = []
|
732 |
+
for smult in multlist:
|
733 |
+
if not is_float(smult):
|
734 |
+
raise gr.Error(f"Lora sub value no {i+1} ({smult}) in Multiplier definition '{multlist}' is invalid")
|
735 |
+
slist.append(float(smult))
|
736 |
+
slist = expand_slist(slist, num_inference_steps )
|
737 |
+
list_mult_choices_nums.append(slist)
|
738 |
+
else:
|
739 |
+
if not is_float(mult):
|
740 |
+
raise gr.Error(f"Lora Multiplier no {i+1} ({mult}) is invalid")
|
741 |
+
list_mult_choices_nums.append(float(mult))
|
742 |
+
if len(list_mult_choices_nums ) < len(loras_choices):
|
743 |
+
list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) )
|
744 |
+
|
745 |
+
offload.activate_loras(wan_model.model, loras_choices, list_mult_choices_nums)
|
746 |
+
|
747 |
+
seed = None if seed == -1 else seed
|
748 |
+
# negative_prompt = "" # not applicable in the inference
|
749 |
+
|
750 |
+
if "abort" in state:
|
751 |
+
del state["abort"]
|
752 |
+
state["in_progress"] = True
|
753 |
+
state["selected"] = 0
|
754 |
+
|
755 |
+
enable_riflex = RIFLEx_setting == 0 and video_length > (5* 24) or RIFLEx_setting == 1
|
756 |
+
# VAE Tiling
|
757 |
+
device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
|
758 |
+
|
759 |
+
|
760 |
+
# TeaCache
|
761 |
+
trans = wan_model.model
|
762 |
+
trans.enable_teacache = tea_cache > 0
|
763 |
+
|
764 |
+
import random
|
765 |
+
if seed == None or seed <0:
|
766 |
+
seed = random.randint(0, 999999999)
|
767 |
+
|
768 |
+
file_list = []
|
769 |
+
state["file_list"] = file_list
|
770 |
+
from einops import rearrange
|
771 |
+
save_path = os.path.join(os.getcwd(), "gradio_outputs")
|
772 |
+
os.makedirs(save_path, exist_ok=True)
|
773 |
+
prompts = prompt.replace("\r", "").split("\n")
|
774 |
+
video_no = 0
|
775 |
+
total_video = repeat_generation * len(prompts)
|
776 |
+
abort = False
|
777 |
+
start_time = time.time()
|
778 |
+
for prompt in prompts:
|
779 |
+
for _ in range(repeat_generation):
|
780 |
+
if abort:
|
781 |
+
break
|
782 |
+
|
783 |
+
if trans.enable_teacache:
|
784 |
+
trans.num_steps = num_inference_steps
|
785 |
+
trans.cnt = 0
|
786 |
+
trans.rel_l1_thresh = tea_cache #0.15 # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
|
787 |
+
trans.accumulated_rel_l1_distance = 0
|
788 |
+
trans.previous_modulated_input = None
|
789 |
+
trans.previous_residual = None
|
790 |
+
|
791 |
+
video_no += 1
|
792 |
+
status = f"Video {video_no}/{total_video}"
|
793 |
+
progress(0, desc=status + " - Encoding Prompt" )
|
794 |
+
|
795 |
+
callback = build_callback(state, trans, progress, status, num_inference_steps)
|
796 |
+
|
797 |
+
|
798 |
+
gc.collect()
|
799 |
+
torch.cuda.empty_cache()
|
800 |
+
try:
|
801 |
+
if use_image2video:
|
802 |
+
samples = wan_model.generate(
|
803 |
+
prompt,
|
804 |
+
image_to_continue,
|
805 |
+
frame_num=(video_length // 4)* 4 + 1,
|
806 |
+
max_area=MAX_AREA_CONFIGS[resolution],
|
807 |
+
shift=flow_shift,
|
808 |
+
sampling_steps=num_inference_steps,
|
809 |
+
guide_scale=guidance_scale,
|
810 |
+
n_prompt=negative_prompt,
|
811 |
+
seed=seed,
|
812 |
+
offload_model=False,
|
813 |
+
callback=callback
|
814 |
+
)
|
815 |
+
|
816 |
+
else:
|
817 |
+
samples = wan_model.generate(
|
818 |
+
prompt,
|
819 |
+
frame_num=(video_length // 4)* 4 + 1,
|
820 |
+
size=(width, height),
|
821 |
+
shift=flow_shift,
|
822 |
+
sampling_steps=num_inference_steps,
|
823 |
+
guide_scale=guidance_scale,
|
824 |
+
n_prompt=negative_prompt,
|
825 |
+
seed=seed,
|
826 |
+
offload_model=False,
|
827 |
+
callback=callback
|
828 |
+
)
|
829 |
+
except:
|
830 |
+
gen_in_progress = False
|
831 |
+
if temp_filename!= None and os.path.isfile(temp_filename):
|
832 |
+
os.remove(temp_filename)
|
833 |
+
offload.last_offload_obj.unload_all()
|
834 |
+
# if compile:
|
835 |
+
# cache_size = torch._dynamo.config.cache_size_limit
|
836 |
+
# torch.compiler.reset()
|
837 |
+
# torch._dynamo.config.cache_size_limit = cache_size
|
838 |
+
|
839 |
+
gc.collect()
|
840 |
+
torch.cuda.empty_cache()
|
841 |
+
raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
|
842 |
+
|
843 |
+
|
844 |
+
if samples != None:
|
845 |
+
samples = samples.to("cpu")
|
846 |
+
offload.last_offload_obj.unload_all()
|
847 |
+
gc.collect()
|
848 |
+
torch.cuda.empty_cache()
|
849 |
+
|
850 |
+
if samples == None:
|
851 |
+
end_time = time.time()
|
852 |
+
abort = True
|
853 |
+
yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
|
854 |
+
else:
|
855 |
+
sample = samples.cpu()
|
856 |
+
# video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
|
857 |
+
|
858 |
+
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
|
859 |
+
file_name = f"{time_flag}_seed{seed}_{prompt[:100].replace('/','').strip()}.mp4".replace(':',' ').replace('\\',' ')
|
860 |
+
video_path = os.path.join(os.getcwd(), "gradio_outputs", file_name)
|
861 |
+
cache_video(
|
862 |
+
tensor=sample[None],
|
863 |
+
save_file=video_path,
|
864 |
+
fps=16,
|
865 |
+
nrow=1,
|
866 |
+
normalize=True,
|
867 |
+
value_range=(-1, 1))
|
868 |
+
|
869 |
+
print(f"New video saved to Path: "+video_path)
|
870 |
+
file_list.append(video_path)
|
871 |
+
if video_no < total_video:
|
872 |
+
yield status
|
873 |
+
else:
|
874 |
+
end_time = time.time()
|
875 |
+
yield f"Total Generation Time: {end_time-start_time:.1f}s"
|
876 |
+
seed += 1
|
877 |
+
|
878 |
+
if temp_filename!= None and os.path.isfile(temp_filename):
|
879 |
+
os.remove(temp_filename)
|
880 |
+
gen_in_progress = False
|
881 |
+
|
882 |
+
new_preset_msg = "Enter a Name for a Lora Preset or Choose One Above"
|
883 |
+
|
884 |
+
def save_lset(lset_name, loras_choices, loras_mult_choices):
|
885 |
+
global loras_presets
|
886 |
+
|
887 |
+
if len(lset_name) == 0 or lset_name== new_preset_msg:
|
888 |
+
gr.Info("Please enter a name for the preset")
|
889 |
+
lset_choices =[("Please enter a name for a Lora Preset","")]
|
890 |
+
else:
|
891 |
+
lset_name = sanitize_file_name(lset_name)
|
892 |
+
|
893 |
+
loras_choices_files = [ Path(loras[int(choice_no)]).parts[-1] for choice_no in loras_choices ]
|
894 |
+
lset = {"loras" : loras_choices_files, "loras_mult" : loras_mult_choices}
|
895 |
+
lset_name_filename = lset_name + ".lset"
|
896 |
+
full_lset_name_filename = os.path.join(lora_dir, lset_name_filename)
|
897 |
+
|
898 |
+
with open(full_lset_name_filename, "w", encoding="utf-8") as writer:
|
899 |
+
writer.write(json.dumps(lset))
|
900 |
+
|
901 |
+
if lset_name in loras_presets:
|
902 |
+
gr.Info(f"Lora Preset '{lset_name}' has been updated")
|
903 |
+
else:
|
904 |
+
gr.Info(f"Lora Preset '{lset_name}' has been created")
|
905 |
+
loras_presets.append(Path(Path(lset_name_filename).parts[-1]).stem )
|
906 |
+
lset_choices = [ ( preset, preset) for preset in loras_presets ]
|
907 |
+
lset_choices.append( (new_preset_msg, ""))
|
908 |
+
|
909 |
+
return gr.Dropdown(choices=lset_choices, value= lset_name)
|
910 |
+
|
911 |
+
def delete_lset(lset_name):
|
912 |
+
global loras_presets
|
913 |
+
lset_name_filename = os.path.join(lora_dir, sanitize_file_name(lset_name) + ".lset" )
|
914 |
+
if len(lset_name) > 0 and lset_name != new_preset_msg:
|
915 |
+
if not os.path.isfile(lset_name_filename):
|
916 |
+
raise gr.Error(f"Preset '{lset_name}' not found ")
|
917 |
+
os.remove(lset_name_filename)
|
918 |
+
pos = loras_presets.index(lset_name)
|
919 |
+
gr.Info(f"Lora Preset '{lset_name}' has been deleted")
|
920 |
+
loras_presets.remove(lset_name)
|
921 |
+
else:
|
922 |
+
pos = len(loras_presets)
|
923 |
+
gr.Info(f"Choose a Preset to delete")
|
924 |
+
|
925 |
+
lset_choices = [ (preset, preset) for preset in loras_presets]
|
926 |
+
lset_choices.append((new_preset_msg, ""))
|
927 |
+
return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1])
|
928 |
+
|
929 |
+
def apply_lset(lset_name, loras_choices, loras_mult_choices):
|
930 |
+
|
931 |
+
if len(lset_name) == 0 or lset_name== new_preset_msg:
|
932 |
+
gr.Info("Please choose a preset in the list or create one")
|
933 |
+
else:
|
934 |
+
loras_choices, loras_mult_choices= extract_preset(lset_name, loras)
|
935 |
+
gr.Info(f"Lora Preset '{lset_name}' has been applied")
|
936 |
+
|
937 |
+
return loras_choices, loras_mult_choices
|
938 |
+
|
939 |
+
def create_demo():
|
940 |
+
|
941 |
+
default_inference_steps = 30
|
942 |
+
default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v)
|
943 |
+
|
944 |
+
with gr.Blocks() as demo:
|
945 |
+
state = gr.State({})
|
946 |
+
|
947 |
+
if use_image2video:
|
948 |
+
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1 - AI Image To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
|
949 |
+
else:
|
950 |
+
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1 - AI Text To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
|
951 |
+
|
952 |
+
gr.Markdown("<FONT SIZE=3>With this first release of Wan 2.1GP by <B>DeepBeepMeep</B> the VRAM requirements have been divided by more than 2 with no quality loss</FONT>")
|
953 |
+
|
954 |
+
if use_image2video and False:
|
955 |
+
pass
|
956 |
+
else:
|
957 |
+
gr.Markdown("The VRAM requirements will depend greatly of the resolution and the duration of the video, for instance : 24 GB of VRAM (RTX 3090 / RTX 4090), the limits are as follows:")
|
958 |
+
gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM")
|
959 |
+
gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
|
960 |
+
gr.Markdown("- 1280 x 720 with a 14B model: 192 frames (8s): 11 GB of VRAM")
|
961 |
+
gr.Markdown("Note that the VAE stages (encoding / decoding at image2video ) or just the decoding at text2video will create a temporary VRAM peak (up to 12GB for 420P and 22 GB for 720P)")
|
962 |
+
|
963 |
+
gr.Markdown("Please note that if your turn on compilation, the first generation step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
|
964 |
+
|
965 |
+
|
966 |
+
# css = """<STYLE>
|
967 |
+
# h2 { width: 100%; text-align: center; border-bottom: 1px solid #000; line-height: 0.1em; margin: 10px 0 20px; }
|
968 |
+
# h2 span {background:#fff; padding:0 10px; }</STYLE>"""
|
969 |
+
# gr.HTML(css)
|
970 |
+
|
971 |
+
header = gr.Markdown(generate_header(transformer_filename_i2v if use_image2video else transformer_filename_t2v, compile, attention_mode) )
|
972 |
+
|
973 |
+
with gr.Accordion("Video Engine Configuration - click here to change it", open = False):
|
974 |
+
gr.Markdown("For the changes to be effective you will need to restart the gradio_server. Some choices below may be locked if the app has been launched by specifying a config preset.")
|
975 |
+
|
976 |
+
with gr.Column():
|
977 |
+
index = transformer_choices_t2v.index(transformer_filename_t2v)
|
978 |
+
index = 0 if index ==0 else index
|
979 |
+
transformer_t2v_choice = gr.Dropdown(
|
980 |
+
choices=[
|
981 |
+
("WAN 2.1 1.3B Text to Video 16 bits - the small model for fast generations with low VRAM requirements", 0),
|
982 |
+
("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1),
|
983 |
+
("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2),
|
984 |
+
],
|
985 |
+
value= index,
|
986 |
+
label="Transformer model for Text to Video",
|
987 |
+
interactive= not lock_ui_transformer,
|
988 |
+
visible=not use_image2video
|
989 |
+
)
|
990 |
+
|
991 |
+
index = transformer_choices_i2v.index(transformer_filename_i2v)
|
992 |
+
index = 0 if index ==0 else index
|
993 |
+
transformer_i2v_choice = gr.Dropdown(
|
994 |
+
choices=[
|
995 |
+
("WAN 2.1 - 480p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 0),
|
996 |
+
("WAN 2.1 - 480p 14B Image to Video quantized to 8 bits (recommended) - the default engine but quantized", 1),
|
997 |
+
("WAN 2.1 - 720p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 2),
|
998 |
+
("WAN 2.1 - 720p 14B Image to Video quantized to 8 bits (recommended) - the default engine but quantized", 3),
|
999 |
+
],
|
1000 |
+
value= index,
|
1001 |
+
label="Transformer model for Image to Video",
|
1002 |
+
interactive= not lock_ui_transformer,
|
1003 |
+
visible = use_image2video, ###############
|
1004 |
+
)
|
1005 |
+
|
1006 |
+
index = text_encoder_choices.index(text_encoder_filename)
|
1007 |
+
index = 0 if index ==0 else index
|
1008 |
+
|
1009 |
+
text_encoder_choice = gr.Dropdown(
|
1010 |
+
choices=[
|
1011 |
+
("UMT5 XXL 16 bits - unquantized text encoder, better quality uses more RAM", 0),
|
1012 |
+
("UMT5 XXL quantized to 8 bits - quantized text encoder, slightly worse quality but uses less RAM", 1),
|
1013 |
+
],
|
1014 |
+
value= index,
|
1015 |
+
label="Text Encoder model"
|
1016 |
+
)
|
1017 |
+
def check(mode):
|
1018 |
+
if not mode in attention_modes_supported:
|
1019 |
+
return " (NOT INSTALLED)"
|
1020 |
+
else:
|
1021 |
+
return ""
|
1022 |
+
attention_choice = gr.Dropdown(
|
1023 |
+
choices=[
|
1024 |
+
("Auto : pick sage2 > sage > sdpa depending on what is installed", "auto"),
|
1025 |
+
("Scale Dot Product Attention: default, always available", "sdpa"),
|
1026 |
+
("Flash" + check("flash")+ ": good quality - requires additional install (usually complex to set up on Windows without WSL)", "flash"),
|
1027 |
+
# ("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"),
|
1028 |
+
("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"),
|
1029 |
+
("Sage2" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"),
|
1030 |
+
],
|
1031 |
+
value= attention_mode,
|
1032 |
+
label="Attention Type",
|
1033 |
+
interactive= not lock_ui_attention
|
1034 |
+
)
|
1035 |
+
gr.Markdown("Beware: when restarting the server or changing a resolution or video duration, the first step of generation for a duration / resolution may last a few minutes due to recompilation")
|
1036 |
+
compile_choice = gr.Dropdown(
|
1037 |
+
choices=[
|
1038 |
+
("ON: works only on Linux / WSL", "transformer"),
|
1039 |
+
("OFF: no other choice if you have Windows without using WSL", "" ),
|
1040 |
+
],
|
1041 |
+
value= compile,
|
1042 |
+
label="Compile Transformer (up to 50% faster and 30% more frames but requires Linux / WSL and Flash or Sage attention)",
|
1043 |
+
interactive= not lock_ui_compile
|
1044 |
+
)
|
1045 |
+
|
1046 |
+
|
1047 |
+
vae_config_choice = gr.Dropdown(
|
1048 |
+
choices=[
|
1049 |
+
("Auto", 0),
|
1050 |
+
("Disabled (faster but may require up to 24 GB of VRAM)", 1),
|
1051 |
+
("Enabled (2x slower and up to 50% VRAM reduction)", 2),
|
1052 |
+
],
|
1053 |
+
value= vae_config,
|
1054 |
+
label="VAE optimisations - reduce the VRAM requirements for VAE decoding and VAE encoding"
|
1055 |
+
)
|
1056 |
+
|
1057 |
+
profile_choice = gr.Dropdown(
|
1058 |
+
choices=[
|
1059 |
+
("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for short videos a RTX 3090 / RTX 4090", 1),
|
1060 |
+
("HighRAM_LowVRAM, profile 2 (Recommended): at least 48 GB of RAM and 12 GB of VRAM, the most versatile profile with high RAM, better suited for RTX 3070/3080/4070/4080 or for RTX 3090 / RTX 4090 with large pictures batches or long videos", 2),
|
1061 |
+
("LowRAM_HighVRAM, profile 3: at least 32 GB of RAM and 24 GB of VRAM, adapted for RTX 3090 / RTX 4090 with limited RAM for good speed short video",3),
|
1062 |
+
("LowRAM_LowVRAM, profile 4 (Default): at least 32 GB of RAM and 12 GB of VRAM, if you have little VRAM or want to generate longer videos",4),
|
1063 |
+
("VerylowRAM_LowVRAM, profile 5: (Fail safe): at least 16 GB of RAM and 10 GB of VRAM, if you don't have much it won't be fast but maybe it will work",5)
|
1064 |
+
],
|
1065 |
+
value= profile,
|
1066 |
+
label="Profile (for power users only, not needed to change it)"
|
1067 |
+
)
|
1068 |
+
|
1069 |
+
default_ui_choice = gr.Dropdown(
|
1070 |
+
choices=[
|
1071 |
+
("Text to Video", "t2v"),
|
1072 |
+
("Image to Video", "i2v"),
|
1073 |
+
],
|
1074 |
+
value= default_ui,
|
1075 |
+
label="Default mode when launching the App if not '--t2v' ot '--i2v' switch is specified when launching the server ",
|
1076 |
+
# visible= True ############
|
1077 |
+
)
|
1078 |
+
|
1079 |
+
msg = gr.Markdown()
|
1080 |
+
apply_btn = gr.Button("Apply Changes")
|
1081 |
+
|
1082 |
+
|
1083 |
+
with gr.Row():
|
1084 |
+
with gr.Column():
|
1085 |
+
video_to_continue = gr.Video(label= "Video to continue", visible= use_image2video and False) #######
|
1086 |
+
image_to_continue = gr.Image(label= "Image as a starting point for a new video", visible=use_image2video)
|
1087 |
+
|
1088 |
+
if use_image2video:
|
1089 |
+
prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos)", value="Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.", lines=3)
|
1090 |
+
else:
|
1091 |
+
prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos)", value="A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect.", lines=3)
|
1092 |
+
|
1093 |
+
|
1094 |
+
with gr.Row():
|
1095 |
+
resolution = gr.Dropdown(
|
1096 |
+
choices=[
|
1097 |
+
# 720p
|
1098 |
+
("1280x720 (16:9, 720p)", "1280x720"),
|
1099 |
+
("720x1280 (9:16, 720p)", "720x1280"),
|
1100 |
+
("1024x1024 (4:3, 720p, T2V only)", "1024x024"),
|
1101 |
+
# ("832x1104 (3:4, 720p)", "832x1104"),
|
1102 |
+
# ("960x960 (1:1, 720p)", "960x960"),
|
1103 |
+
# 480p
|
1104 |
+
# ("960x544 (16:9, 480p)", "960x544"),
|
1105 |
+
("832x480 (16:9, 480p)", "832x480"),
|
1106 |
+
("480x832 (9:16, 480p)", "480x832"),
|
1107 |
+
# ("832x624 (4:3, 540p)", "832x624"),
|
1108 |
+
# ("624x832 (3:4, 540p)", "624x832"),
|
1109 |
+
# ("720x720 (1:1, 540p)", "720x720"),
|
1110 |
+
],
|
1111 |
+
value="832x480",
|
1112 |
+
label="Resolution"
|
1113 |
+
)
|
1114 |
+
|
1115 |
+
with gr.Row():
|
1116 |
+
with gr.Column():
|
1117 |
+
video_length = gr.Slider(5, 169, value=81, step=4, label="Number of frames (16 = 1s)")
|
1118 |
+
with gr.Column():
|
1119 |
+
num_inference_steps = gr.Slider(1, 100, value= default_inference_steps, step=1, label="Number of Inference Steps")
|
1120 |
+
|
1121 |
+
with gr.Row():
|
1122 |
+
max_frames = gr.Slider(1, 100, value=9, step=1, label="Number of input frames to use for Video2World prediction", visible=use_image2video and False) #########
|
1123 |
+
|
1124 |
+
|
1125 |
+
with gr.Row(visible= len(loras)>0):
|
1126 |
+
lset_choices = [ (preset, preset) for preset in loras_presets ] + [(new_preset_msg, "")]
|
1127 |
+
with gr.Column(scale=5):
|
1128 |
+
lset_name = gr.Dropdown(show_label=False, allow_custom_value= True, scale=5, filterable=False, choices= lset_choices, value=default_lora_preset)
|
1129 |
+
with gr.Column(scale=1):
|
1130 |
+
# with gr.Column():
|
1131 |
+
with gr.Row(height=17):
|
1132 |
+
apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1)
|
1133 |
+
with gr.Row(height=17):
|
1134 |
+
save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
|
1135 |
+
delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
|
1136 |
+
|
1137 |
+
|
1138 |
+
loras_choices = gr.Dropdown(
|
1139 |
+
choices=[
|
1140 |
+
(lora_name, str(i) ) for i, lora_name in enumerate(loras_names)
|
1141 |
+
],
|
1142 |
+
value= default_loras_choices,
|
1143 |
+
multiselect= True,
|
1144 |
+
visible= len(loras)>0,
|
1145 |
+
label="Activated Loras"
|
1146 |
+
)
|
1147 |
+
loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns", value=default_loras_multis_str, visible= len(loras)>0 )
|
1148 |
+
|
1149 |
+
show_advanced = gr.Checkbox(label="Show Advanced Options", value=False)
|
1150 |
+
with gr.Row(visible=False) as advanced_row:
|
1151 |
+
with gr.Column():
|
1152 |
+
seed = gr.Slider(-1, 999999999, value=-1, step=1, label="Seed (-1 for random)")
|
1153 |
+
repeat_generation = gr.Slider(1, 25.0, value=1.0, step=1, label="Number of Generated Video per prompt")
|
1154 |
+
with gr.Row():
|
1155 |
+
negative_prompt = gr.Textbox(label="Negative Prompt", value="")
|
1156 |
+
with gr.Row():
|
1157 |
+
guidance_scale = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Guidance Scale", visible=True)
|
1158 |
+
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
|
1159 |
+
flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale")
|
1160 |
+
tea_cache_setting = gr.Dropdown(
|
1161 |
+
choices=[
|
1162 |
+
("Disabled", 0),
|
1163 |
+
("Fast (x1.6 speed up)", 0.1),
|
1164 |
+
("Faster (x2.1 speed up)", 0.15),
|
1165 |
+
],
|
1166 |
+
value=default_tea_cache,
|
1167 |
+
visible=False,
|
1168 |
+
label="Tea Cache acceleration (the faster the acceleration the higher the degradation of the quality of the video. Consumes VRAM)"
|
1169 |
+
)
|
1170 |
+
|
1171 |
+
RIFLEx_setting = gr.Dropdown(
|
1172 |
+
choices=[
|
1173 |
+
("Auto (ON if Video longer than 5s)", 0),
|
1174 |
+
("Always ON", 1),
|
1175 |
+
("Always OFF", 2),
|
1176 |
+
],
|
1177 |
+
value=0,
|
1178 |
+
label="RIFLEx positional embedding to generate long video"
|
1179 |
+
)
|
1180 |
+
|
1181 |
+
show_advanced.change(fn=lambda x: gr.Row(visible=x), inputs=[show_advanced], outputs=[advanced_row])
|
1182 |
+
|
1183 |
+
with gr.Column():
|
1184 |
+
gen_status = gr.Text(label="Status", interactive= False)
|
1185 |
+
output = gr.Gallery(
|
1186 |
+
label="Generated videos", show_label=False, elem_id="gallery"
|
1187 |
+
, columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= False)
|
1188 |
+
generate_btn = gr.Button("Generate")
|
1189 |
+
abort_btn = gr.Button("Abort")
|
1190 |
+
|
1191 |
+
save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices], outputs=[lset_name])
|
1192 |
+
delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name])
|
1193 |
+
apply_lset_btn.click(apply_lset, inputs=[lset_name,loras_choices, loras_mult_choices], outputs=[loras_choices, loras_mult_choices])
|
1194 |
+
|
1195 |
+
gen_status.change(refresh_gallery, inputs = [state], outputs = output )
|
1196 |
+
|
1197 |
+
abort_btn.click(abort_generation,state,abort_btn )
|
1198 |
+
output.select(select_video, state, None )
|
1199 |
+
|
1200 |
+
generate_btn.click(
|
1201 |
+
fn=generate_video,
|
1202 |
+
inputs=[
|
1203 |
+
prompt,
|
1204 |
+
negative_prompt,
|
1205 |
+
resolution,
|
1206 |
+
video_length,
|
1207 |
+
seed,
|
1208 |
+
num_inference_steps,
|
1209 |
+
guidance_scale,
|
1210 |
+
flow_shift,
|
1211 |
+
embedded_guidance_scale,
|
1212 |
+
repeat_generation,
|
1213 |
+
tea_cache_setting,
|
1214 |
+
loras_choices,
|
1215 |
+
loras_mult_choices,
|
1216 |
+
image_to_continue,
|
1217 |
+
video_to_continue,
|
1218 |
+
max_frames,
|
1219 |
+
RIFLEx_setting,
|
1220 |
+
state
|
1221 |
+
],
|
1222 |
+
outputs= [gen_status] #,state
|
1223 |
+
|
1224 |
+
).then(
|
1225 |
+
finalize_gallery,
|
1226 |
+
[state],
|
1227 |
+
[output , abort_btn]
|
1228 |
+
)
|
1229 |
+
|
1230 |
+
apply_btn.click(
|
1231 |
+
fn=apply_changes,
|
1232 |
+
inputs=[
|
1233 |
+
state,
|
1234 |
+
transformer_t2v_choice,
|
1235 |
+
transformer_i2v_choice,
|
1236 |
+
text_encoder_choice,
|
1237 |
+
attention_choice,
|
1238 |
+
compile_choice,
|
1239 |
+
profile_choice,
|
1240 |
+
vae_config_choice,
|
1241 |
+
default_ui_choice,
|
1242 |
+
],
|
1243 |
+
outputs= msg
|
1244 |
+
).then(
|
1245 |
+
update_defaults,
|
1246 |
+
[state, num_inference_steps, flow_shift],
|
1247 |
+
[num_inference_steps, flow_shift, header]
|
1248 |
+
)
|
1249 |
+
|
1250 |
+
return demo
|
1251 |
+
|
1252 |
+
if __name__ == "__main__":
|
1253 |
+
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
1254 |
+
server_port = int(args.server_port)
|
1255 |
+
|
1256 |
+
if server_port == 0:
|
1257 |
+
server_port = int(os.getenv("SERVER_PORT", "7860"))
|
1258 |
+
|
1259 |
+
server_name = args.server_name
|
1260 |
+
if len(server_name) == 0:
|
1261 |
+
server_name = os.getenv("SERVER_NAME", "localhost")
|
1262 |
+
|
1263 |
+
|
1264 |
+
demo = create_demo()
|
1265 |
+
if args.open_browser:
|
1266 |
+
import webbrowser
|
1267 |
+
if server_name.startswith("http"):
|
1268 |
+
url = server_name
|
1269 |
+
else:
|
1270 |
+
url = "http://" + server_name
|
1271 |
+
webbrowser.open(url + ":" + str(server_port), new = 0, autoraise = True)
|
1272 |
+
|
1273 |
+
demo.launch(server_name=server_name, server_port=server_port)
|
1274 |
+
|
1275 |
+
|
loras/README.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Put here Loras
|
loras_i2v/README.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Put here Loras
|
requirements.txt
CHANGED
@@ -11,6 +11,9 @@ easydict
|
|
11 |
ftfy
|
12 |
dashscope
|
13 |
imageio-ffmpeg
|
14 |
-
flash_attn
|
15 |
gradio>=5.0.0
|
16 |
numpy>=1.23.5,<2
|
|
|
|
|
|
|
|
11 |
ftfy
|
12 |
dashscope
|
13 |
imageio-ffmpeg
|
14 |
+
# flash_attn
|
15 |
gradio>=5.0.0
|
16 |
numpy>=1.23.5,<2
|
17 |
+
einops
|
18 |
+
moviepy==1.0.3
|
19 |
+
mmgp==3.2.1
|
wan/image2video.py
CHANGED
@@ -39,6 +39,9 @@ class WanI2V:
|
|
39 |
use_usp=False,
|
40 |
t5_cpu=False,
|
41 |
init_on_cpu=True,
|
|
|
|
|
|
|
42 |
):
|
43 |
r"""
|
44 |
Initializes the image-to-video generation model components.
|
@@ -77,7 +80,7 @@ class WanI2V:
|
|
77 |
text_len=config.text_len,
|
78 |
dtype=config.t5_dtype,
|
79 |
device=torch.device('cpu'),
|
80 |
-
checkpoint_path=
|
81 |
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
|
82 |
shard_fn=shard_fn if t5_fsdp else None,
|
83 |
)
|
@@ -95,8 +98,10 @@ class WanI2V:
|
|
95 |
config.clip_checkpoint),
|
96 |
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
|
97 |
|
98 |
-
logging.info(f"Creating WanModel from {
|
99 |
-
|
|
|
|
|
100 |
self.model.eval().requires_grad_(False)
|
101 |
|
102 |
if t5_fsdp or dit_fsdp or use_usp:
|
@@ -116,28 +121,30 @@ class WanI2V:
|
|
116 |
else:
|
117 |
self.sp_size = 1
|
118 |
|
119 |
-
if dist.is_initialized():
|
120 |
-
|
121 |
-
if dit_fsdp:
|
122 |
-
|
123 |
-
else:
|
124 |
-
|
125 |
-
|
126 |
|
127 |
self.sample_neg_prompt = config.sample_neg_prompt
|
128 |
|
129 |
def generate(self,
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
|
|
|
|
141 |
r"""
|
142 |
Generates video frames from input image and text prompt using diffusion process.
|
143 |
|
@@ -197,14 +204,14 @@ class WanI2V:
|
|
197 |
seed_g.manual_seed(seed)
|
198 |
noise = torch.randn(
|
199 |
16,
|
200 |
-
21,
|
201 |
lat_h,
|
202 |
lat_w,
|
203 |
dtype=torch.float32,
|
204 |
generator=seed_g,
|
205 |
device=self.device)
|
206 |
|
207 |
-
msk = torch.ones(1,
|
208 |
msk[:, 1:] = 0
|
209 |
msk = torch.concat([
|
210 |
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
|
@@ -218,7 +225,7 @@ class WanI2V:
|
|
218 |
|
219 |
# preprocess
|
220 |
if not self.t5_cpu:
|
221 |
-
self.text_encoder.model.to(self.device)
|
222 |
context = self.text_encoder([input_prompt], self.device)
|
223 |
context_null = self.text_encoder([n_prompt], self.device)
|
224 |
if offload_model:
|
@@ -229,20 +236,23 @@ class WanI2V:
|
|
229 |
context = [t.to(self.device) for t in context]
|
230 |
context_null = [t.to(self.device) for t in context_null]
|
231 |
|
232 |
-
self.clip.model.to(self.device)
|
233 |
clip_context = self.clip.visual([img[:, None, :, :]])
|
234 |
if offload_model:
|
235 |
self.clip.model.cpu()
|
236 |
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
])
|
|
|
|
|
|
|
246 |
y = torch.concat([msk, y])
|
247 |
|
248 |
@contextmanager
|
@@ -283,6 +293,7 @@ class WanI2V:
|
|
283 |
'clip_fea': clip_context,
|
284 |
'seq_len': max_seq_len,
|
285 |
'y': [y],
|
|
|
286 |
}
|
287 |
|
288 |
arg_null = {
|
@@ -290,30 +301,39 @@ class WanI2V:
|
|
290 |
'clip_fea': clip_context,
|
291 |
'seq_len': max_seq_len,
|
292 |
'y': [y],
|
|
|
293 |
}
|
294 |
|
295 |
if offload_model:
|
296 |
torch.cuda.empty_cache()
|
297 |
|
298 |
-
self.model.to(self.device)
|
299 |
-
|
|
|
|
|
|
|
|
|
300 |
latent_model_input = [latent.to(self.device)]
|
301 |
timestep = [t]
|
302 |
|
303 |
timestep = torch.stack(timestep).to(self.device)
|
304 |
|
305 |
noise_pred_cond = self.model(
|
306 |
-
latent_model_input, t=timestep, **arg_c)[0]
|
307 |
-
|
|
|
308 |
if offload_model:
|
309 |
torch.cuda.empty_cache()
|
310 |
noise_pred_uncond = self.model(
|
311 |
-
latent_model_input, t=timestep, **arg_null)[0]
|
312 |
-
|
|
|
|
|
313 |
if offload_model:
|
314 |
torch.cuda.empty_cache()
|
315 |
noise_pred = noise_pred_uncond + guide_scale * (
|
316 |
noise_pred_cond - noise_pred_uncond)
|
|
|
317 |
|
318 |
latent = latent.to(
|
319 |
torch.device('cpu') if offload_model else self.device)
|
@@ -325,9 +345,14 @@ class WanI2V:
|
|
325 |
return_dict=False,
|
326 |
generator=seed_g)[0]
|
327 |
latent = temp_x0.squeeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
328 |
|
329 |
-
|
330 |
-
del latent_model_input, timestep
|
331 |
|
332 |
if offload_model:
|
333 |
self.model.cpu()
|
|
|
39 |
use_usp=False,
|
40 |
t5_cpu=False,
|
41 |
init_on_cpu=True,
|
42 |
+
i2v720p= True,
|
43 |
+
model_filename ="",
|
44 |
+
text_encoder_filename="",
|
45 |
):
|
46 |
r"""
|
47 |
Initializes the image-to-video generation model components.
|
|
|
80 |
text_len=config.text_len,
|
81 |
dtype=config.t5_dtype,
|
82 |
device=torch.device('cpu'),
|
83 |
+
checkpoint_path=text_encoder_filename,
|
84 |
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
|
85 |
shard_fn=shard_fn if t5_fsdp else None,
|
86 |
)
|
|
|
98 |
config.clip_checkpoint),
|
99 |
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
|
100 |
|
101 |
+
logging.info(f"Creating WanModel from {model_filename}")
|
102 |
+
from mmgp import offload
|
103 |
+
|
104 |
+
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel)
|
105 |
self.model.eval().requires_grad_(False)
|
106 |
|
107 |
if t5_fsdp or dit_fsdp or use_usp:
|
|
|
121 |
else:
|
122 |
self.sp_size = 1
|
123 |
|
124 |
+
# if dist.is_initialized():
|
125 |
+
# dist.barrier()
|
126 |
+
# if dit_fsdp:
|
127 |
+
# self.model = shard_fn(self.model)
|
128 |
+
# else:
|
129 |
+
# if not init_on_cpu:
|
130 |
+
# self.model.to(self.device)
|
131 |
|
132 |
self.sample_neg_prompt = config.sample_neg_prompt
|
133 |
|
134 |
def generate(self,
|
135 |
+
input_prompt,
|
136 |
+
img,
|
137 |
+
max_area=720 * 1280,
|
138 |
+
frame_num=81,
|
139 |
+
shift=5.0,
|
140 |
+
sample_solver='unipc',
|
141 |
+
sampling_steps=40,
|
142 |
+
guide_scale=5.0,
|
143 |
+
n_prompt="",
|
144 |
+
seed=-1,
|
145 |
+
offload_model=True,
|
146 |
+
callback = None
|
147 |
+
):
|
148 |
r"""
|
149 |
Generates video frames from input image and text prompt using diffusion process.
|
150 |
|
|
|
204 |
seed_g.manual_seed(seed)
|
205 |
noise = torch.randn(
|
206 |
16,
|
207 |
+
int((frame_num - 1)/4 + 1), #21,
|
208 |
lat_h,
|
209 |
lat_w,
|
210 |
dtype=torch.float32,
|
211 |
generator=seed_g,
|
212 |
device=self.device)
|
213 |
|
214 |
+
msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device)
|
215 |
msk[:, 1:] = 0
|
216 |
msk = torch.concat([
|
217 |
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
|
|
|
225 |
|
226 |
# preprocess
|
227 |
if not self.t5_cpu:
|
228 |
+
# self.text_encoder.model.to(self.device)
|
229 |
context = self.text_encoder([input_prompt], self.device)
|
230 |
context_null = self.text_encoder([n_prompt], self.device)
|
231 |
if offload_model:
|
|
|
236 |
context = [t.to(self.device) for t in context]
|
237 |
context_null = [t.to(self.device) for t in context_null]
|
238 |
|
239 |
+
# self.clip.model.to(self.device)
|
240 |
clip_context = self.clip.visual([img[:, None, :, :]])
|
241 |
if offload_model:
|
242 |
self.clip.model.cpu()
|
243 |
|
244 |
+
from mmgp import offload
|
245 |
+
|
246 |
+
offload.last_offload_obj.unload_all()
|
247 |
+
enc= torch.concat([
|
248 |
+
torch.nn.functional.interpolate(
|
249 |
+
img[None].cpu(), size=(h, w), mode='bicubic').transpose(
|
250 |
+
0, 1).to(torch.bfloat16),
|
251 |
+
torch.zeros(3, frame_num-1, h, w, device="cpu", dtype= torch.bfloat16)
|
252 |
+
], dim=1).to(self.device)
|
253 |
+
# enc = None
|
254 |
+
|
255 |
+
y = self.vae.encode([enc])[0]
|
256 |
y = torch.concat([msk, y])
|
257 |
|
258 |
@contextmanager
|
|
|
293 |
'clip_fea': clip_context,
|
294 |
'seq_len': max_seq_len,
|
295 |
'y': [y],
|
296 |
+
'pipeline' : self
|
297 |
}
|
298 |
|
299 |
arg_null = {
|
|
|
301 |
'clip_fea': clip_context,
|
302 |
'seq_len': max_seq_len,
|
303 |
'y': [y],
|
304 |
+
'pipeline' : self
|
305 |
}
|
306 |
|
307 |
if offload_model:
|
308 |
torch.cuda.empty_cache()
|
309 |
|
310 |
+
# self.model.to(self.device)
|
311 |
+
if callback != None:
|
312 |
+
callback(-1, None)
|
313 |
+
|
314 |
+
self._interrupt = False
|
315 |
+
for i, t in enumerate(tqdm(timesteps)):
|
316 |
latent_model_input = [latent.to(self.device)]
|
317 |
timestep = [t]
|
318 |
|
319 |
timestep = torch.stack(timestep).to(self.device)
|
320 |
|
321 |
noise_pred_cond = self.model(
|
322 |
+
latent_model_input, t=timestep, **arg_c)[0]
|
323 |
+
if self._interrupt:
|
324 |
+
return None
|
325 |
if offload_model:
|
326 |
torch.cuda.empty_cache()
|
327 |
noise_pred_uncond = self.model(
|
328 |
+
latent_model_input, t=timestep, **arg_null)[0]
|
329 |
+
if self._interrupt:
|
330 |
+
return None
|
331 |
+
del latent_model_input
|
332 |
if offload_model:
|
333 |
torch.cuda.empty_cache()
|
334 |
noise_pred = noise_pred_uncond + guide_scale * (
|
335 |
noise_pred_cond - noise_pred_uncond)
|
336 |
+
del noise_pred_uncond
|
337 |
|
338 |
latent = latent.to(
|
339 |
torch.device('cpu') if offload_model else self.device)
|
|
|
345 |
return_dict=False,
|
346 |
generator=seed_g)[0]
|
347 |
latent = temp_x0.squeeze(0)
|
348 |
+
del temp_x0
|
349 |
+
del timestep
|
350 |
+
|
351 |
+
if callback is not None:
|
352 |
+
callback(i, latent)
|
353 |
+
|
354 |
|
355 |
+
x0 = [latent.to(self.device)]
|
|
|
356 |
|
357 |
if offload_model:
|
358 |
self.model.cpu()
|
wan/modules/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from .attention import
|
2 |
from .model import WanModel
|
3 |
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
|
4 |
from .tokenizers import HuggingfaceTokenizer
|
@@ -12,5 +12,5 @@ __all__ = [
|
|
12 |
'T5Decoder',
|
13 |
'T5EncoderModel',
|
14 |
'HuggingfaceTokenizer',
|
15 |
-
'
|
16 |
]
|
|
|
1 |
+
from .attention import pay_attention
|
2 |
from .model import WanModel
|
3 |
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
|
4 |
from .tokenizers import HuggingfaceTokenizer
|
|
|
12 |
'T5Decoder',
|
13 |
'T5EncoderModel',
|
14 |
'HuggingfaceTokenizer',
|
15 |
+
'pay_attention',
|
16 |
]
|
wan/modules/attention.py
CHANGED
@@ -1,5 +1,9 @@
|
|
1 |
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
2 |
import torch
|
|
|
|
|
|
|
|
|
3 |
|
4 |
try:
|
5 |
import flash_attn_interface
|
@@ -12,19 +16,99 @@ try:
|
|
12 |
FLASH_ATTN_2_AVAILABLE = True
|
13 |
except ModuleNotFoundError:
|
14 |
FLASH_ATTN_2_AVAILABLE = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
import warnings
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
__all__ = [
|
19 |
-
'
|
20 |
'attention',
|
21 |
]
|
22 |
|
23 |
|
24 |
-
def
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
28 |
q_lens=None,
|
29 |
k_lens=None,
|
30 |
dropout_p=0.,
|
@@ -49,6 +133,10 @@ def flash_attention(
|
|
49 |
deterministic: bool. If True, slightly slower and uses more memory.
|
50 |
dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
|
51 |
"""
|
|
|
|
|
|
|
|
|
52 |
half_dtypes = (torch.float16, torch.bfloat16)
|
53 |
assert dtype in half_dtypes
|
54 |
assert q.device.type == 'cuda' and q.size(-1) <= 256
|
@@ -91,7 +179,27 @@ def flash_attention(
|
|
91 |
)
|
92 |
|
93 |
# apply attention
|
94 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
# Note: dropout_p, window_size are not supported in FA3 now.
|
96 |
x = flash_attn_interface.flash_attn_varlen_func(
|
97 |
q=q,
|
@@ -108,8 +216,7 @@ def flash_attention(
|
|
108 |
softmax_scale=softmax_scale,
|
109 |
causal=causal,
|
110 |
deterministic=deterministic)[0].unflatten(0, (b, lq))
|
111 |
-
|
112 |
-
assert FLASH_ATTN_2_AVAILABLE
|
113 |
x = flash_attn.flash_attn_varlen_func(
|
114 |
q=q,
|
115 |
k=k,
|
@@ -146,7 +253,7 @@ def attention(
|
|
146 |
fa_version=None,
|
147 |
):
|
148 |
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
|
149 |
-
return
|
150 |
q=q,
|
151 |
k=k,
|
152 |
v=v,
|
|
|
1 |
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
2 |
import torch
|
3 |
+
from importlib.metadata import version
|
4 |
+
from mmgp import offload
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
|
8 |
try:
|
9 |
import flash_attn_interface
|
|
|
16 |
FLASH_ATTN_2_AVAILABLE = True
|
17 |
except ModuleNotFoundError:
|
18 |
FLASH_ATTN_2_AVAILABLE = False
|
19 |
+
flash_attn = None
|
20 |
+
|
21 |
+
try:
|
22 |
+
from sageattention import sageattn_varlen
|
23 |
+
def sageattn_varlen_wrapper(
|
24 |
+
q,
|
25 |
+
k,
|
26 |
+
v,
|
27 |
+
cu_seqlens_q,
|
28 |
+
cu_seqlens_kv,
|
29 |
+
max_seqlen_q,
|
30 |
+
max_seqlen_kv,
|
31 |
+
):
|
32 |
+
return sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
|
33 |
+
except ImportError:
|
34 |
+
sageattn_varlen_wrapper = None
|
35 |
+
|
36 |
|
37 |
import warnings
|
38 |
|
39 |
+
try:
|
40 |
+
from sageattention import sageattn
|
41 |
+
@torch.compiler.disable()
|
42 |
+
def sageattn_wrapper(
|
43 |
+
qkv_list,
|
44 |
+
attention_length
|
45 |
+
):
|
46 |
+
q,k, v = qkv_list
|
47 |
+
padding_length = q.shape[0] -attention_length
|
48 |
+
q = q[:attention_length, :, : ].unsqueeze(0)
|
49 |
+
k = k[:attention_length, :, : ].unsqueeze(0)
|
50 |
+
v = v[:attention_length, :, : ].unsqueeze(0)
|
51 |
+
|
52 |
+
o = sageattn(q, k, v, tensor_layout="NHD").squeeze(0)
|
53 |
+
del q, k ,v
|
54 |
+
qkv_list.clear()
|
55 |
+
|
56 |
+
if padding_length > 0:
|
57 |
+
o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
|
58 |
+
|
59 |
+
return o
|
60 |
+
except ImportError:
|
61 |
+
sageattn = None
|
62 |
+
|
63 |
+
|
64 |
+
@torch.compiler.disable()
|
65 |
+
def sdpa_wrapper(
|
66 |
+
qkv_list,
|
67 |
+
attention_length
|
68 |
+
):
|
69 |
+
q,k, v = qkv_list
|
70 |
+
padding_length = q.shape[0] -attention_length
|
71 |
+
q = q[:attention_length, :].transpose(0,1).unsqueeze(0)
|
72 |
+
k = k[:attention_length, :].transpose(0,1).unsqueeze(0)
|
73 |
+
v = v[:attention_length, :].transpose(0,1).unsqueeze(0)
|
74 |
+
|
75 |
+
o = F.scaled_dot_product_attention(
|
76 |
+
q, k, v, attn_mask=None, is_causal=False
|
77 |
+
).squeeze(0).transpose(0,1)
|
78 |
+
del q, k ,v
|
79 |
+
qkv_list.clear()
|
80 |
+
|
81 |
+
if padding_length > 0:
|
82 |
+
o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
|
83 |
+
|
84 |
+
return o
|
85 |
+
|
86 |
+
|
87 |
+
def get_attention_modes():
|
88 |
+
ret = ["sdpa", "auto"]
|
89 |
+
if flash_attn != None:
|
90 |
+
ret.append("flash")
|
91 |
+
# if memory_efficient_attention != None:
|
92 |
+
# ret.append("xformers")
|
93 |
+
if sageattn_varlen_wrapper != None:
|
94 |
+
ret.append("sage")
|
95 |
+
if sageattn != None and version("sageattention").startswith("2") :
|
96 |
+
ret.append("sage2")
|
97 |
+
|
98 |
+
return ret
|
99 |
+
|
100 |
+
|
101 |
__all__ = [
|
102 |
+
'pay_attention',
|
103 |
'attention',
|
104 |
]
|
105 |
|
106 |
|
107 |
+
def pay_attention(
|
108 |
+
qkv_list,
|
109 |
+
# q,
|
110 |
+
# k,
|
111 |
+
# v,
|
112 |
q_lens=None,
|
113 |
k_lens=None,
|
114 |
dropout_p=0.,
|
|
|
133 |
deterministic: bool. If True, slightly slower and uses more memory.
|
134 |
dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
|
135 |
"""
|
136 |
+
attn = offload.shared_state["_attention"]
|
137 |
+
q,k,v = qkv_list
|
138 |
+
qkv_list.clear()
|
139 |
+
|
140 |
half_dtypes = (torch.float16, torch.bfloat16)
|
141 |
assert dtype in half_dtypes
|
142 |
assert q.device.type == 'cuda' and q.size(-1) <= 256
|
|
|
179 |
)
|
180 |
|
181 |
# apply attention
|
182 |
+
if attn=="sage":
|
183 |
+
x = sageattn_varlen_wrapper(
|
184 |
+
q=q,
|
185 |
+
k=k,
|
186 |
+
v=v,
|
187 |
+
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
|
188 |
+
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
189 |
+
cu_seqlens_kv=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
|
190 |
+
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
191 |
+
max_seqlen_q=lq,
|
192 |
+
max_seqlen_kv=lk,
|
193 |
+
).unflatten(0, (b, lq))
|
194 |
+
elif attn=="sage2":
|
195 |
+
qkv_list = [q,k,v]
|
196 |
+
del q,k,v
|
197 |
+
x = sageattn_wrapper(qkv_list, lq).unsqueeze(0)
|
198 |
+
elif attn=="sdpa":
|
199 |
+
qkv_list = [q, k, v]
|
200 |
+
del q, k , v
|
201 |
+
x = sdpa_wrapper( qkv_list, lq).unsqueeze(0)
|
202 |
+
elif attn=="flash" and (version is None or version == 3):
|
203 |
# Note: dropout_p, window_size are not supported in FA3 now.
|
204 |
x = flash_attn_interface.flash_attn_varlen_func(
|
205 |
q=q,
|
|
|
216 |
softmax_scale=softmax_scale,
|
217 |
causal=causal,
|
218 |
deterministic=deterministic)[0].unflatten(0, (b, lq))
|
219 |
+
elif attn=="flash":
|
|
|
220 |
x = flash_attn.flash_attn_varlen_func(
|
221 |
q=q,
|
222 |
k=k,
|
|
|
253 |
fa_version=None,
|
254 |
):
|
255 |
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
|
256 |
+
return pay_attention(
|
257 |
q=q,
|
258 |
k=k,
|
259 |
v=v,
|
wan/modules/clip.py
CHANGED
@@ -8,7 +8,7 @@ import torch.nn as nn
|
|
8 |
import torch.nn.functional as F
|
9 |
import torchvision.transforms as T
|
10 |
|
11 |
-
from .attention import
|
12 |
from .tokenizers import HuggingfaceTokenizer
|
13 |
from .xlm_roberta import XLMRoberta
|
14 |
|
@@ -82,7 +82,7 @@ class SelfAttention(nn.Module):
|
|
82 |
|
83 |
# compute attention
|
84 |
p = self.attn_dropout if self.training else 0.0
|
85 |
-
x =
|
86 |
x = x.reshape(b, s, c)
|
87 |
|
88 |
# output
|
@@ -194,7 +194,7 @@ class AttentionPool(nn.Module):
|
|
194 |
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
|
195 |
|
196 |
# compute attention
|
197 |
-
x =
|
198 |
x = x.reshape(b, 1, c)
|
199 |
|
200 |
# output
|
@@ -441,11 +441,12 @@ def _clip(pretrained=False,
|
|
441 |
device='cpu',
|
442 |
**kwargs):
|
443 |
# init a model on device
|
|
|
444 |
with torch.device(device):
|
445 |
model = model_cls(**kwargs)
|
446 |
|
447 |
# set device
|
448 |
-
model = model.to(dtype=dtype, device=device)
|
449 |
output = (model,)
|
450 |
|
451 |
# init transforms
|
@@ -507,16 +508,19 @@ class CLIPModel:
|
|
507 |
self.tokenizer_path = tokenizer_path
|
508 |
|
509 |
# init model
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
|
|
|
|
|
|
516 |
self.model = self.model.eval().requires_grad_(False)
|
517 |
logging.info(f'loading {checkpoint_path}')
|
518 |
self.model.load_state_dict(
|
519 |
-
torch.load(checkpoint_path, map_location='cpu'))
|
520 |
|
521 |
# init tokenizer
|
522 |
self.tokenizer = HuggingfaceTokenizer(
|
|
|
8 |
import torch.nn.functional as F
|
9 |
import torchvision.transforms as T
|
10 |
|
11 |
+
from .attention import pay_attention
|
12 |
from .tokenizers import HuggingfaceTokenizer
|
13 |
from .xlm_roberta import XLMRoberta
|
14 |
|
|
|
82 |
|
83 |
# compute attention
|
84 |
p = self.attn_dropout if self.training else 0.0
|
85 |
+
x = pay_attention([q, k, v], dropout_p=p, causal=self.causal, version=2)
|
86 |
x = x.reshape(b, s, c)
|
87 |
|
88 |
# output
|
|
|
194 |
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
|
195 |
|
196 |
# compute attention
|
197 |
+
x = pay_attention(q, k, v, version=2)
|
198 |
x = x.reshape(b, 1, c)
|
199 |
|
200 |
# output
|
|
|
441 |
device='cpu',
|
442 |
**kwargs):
|
443 |
# init a model on device
|
444 |
+
device ="cpu"
|
445 |
with torch.device(device):
|
446 |
model = model_cls(**kwargs)
|
447 |
|
448 |
# set device
|
449 |
+
# model = model.to(dtype=dtype, device=device)
|
450 |
output = (model,)
|
451 |
|
452 |
# init transforms
|
|
|
508 |
self.tokenizer_path = tokenizer_path
|
509 |
|
510 |
# init model
|
511 |
+
from accelerate import init_empty_weights
|
512 |
+
|
513 |
+
with init_empty_weights():
|
514 |
+
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
|
515 |
+
pretrained=False,
|
516 |
+
return_transforms=True,
|
517 |
+
return_tokenizer=False,
|
518 |
+
dtype=dtype,
|
519 |
+
device=device)
|
520 |
self.model = self.model.eval().requires_grad_(False)
|
521 |
logging.info(f'loading {checkpoint_path}')
|
522 |
self.model.load_state_dict(
|
523 |
+
torch.load(checkpoint_path, map_location='cpu'), assign= True)
|
524 |
|
525 |
# init tokenizer
|
526 |
self.tokenizer = HuggingfaceTokenizer(
|
wan/modules/model.py
CHANGED
@@ -7,7 +7,7 @@ import torch.nn as nn
|
|
7 |
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
8 |
from diffusers.models.modeling_utils import ModelMixin
|
9 |
|
10 |
-
from .attention import
|
11 |
|
12 |
__all__ = ['WanModel']
|
13 |
|
@@ -16,7 +16,7 @@ def sinusoidal_embedding_1d(dim, position):
|
|
16 |
# preprocess
|
17 |
assert dim % 2 == 0
|
18 |
half = dim // 2
|
19 |
-
position = position.type(torch.
|
20 |
|
21 |
# calculation
|
22 |
sinusoid = torch.outer(
|
@@ -25,18 +25,47 @@ def sinusoidal_embedding_1d(dim, position):
|
|
25 |
return x
|
26 |
|
27 |
|
28 |
-
@amp.autocast(enabled=False)
|
29 |
def rope_params(max_seq_len, dim, theta=10000):
|
30 |
assert dim % 2 == 0
|
31 |
freqs = torch.outer(
|
32 |
torch.arange(max_seq_len),
|
33 |
1.0 / torch.pow(theta,
|
34 |
-
torch.arange(0, dim, 2).to(torch.
|
35 |
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
36 |
return freqs
|
37 |
|
38 |
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
def rope_apply(x, grid_sizes, freqs):
|
41 |
n, c = x.size(2), x.size(3) // 2
|
42 |
|
@@ -45,12 +74,17 @@ def rope_apply(x, grid_sizes, freqs):
|
|
45 |
|
46 |
# loop over samples
|
47 |
output = []
|
48 |
-
for i, (f, h, w) in enumerate(grid_sizes
|
49 |
seq_len = f * h * w
|
50 |
|
51 |
# precompute multipliers
|
52 |
-
x_i =
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
54 |
freqs_i = torch.cat([
|
55 |
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
56 |
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
@@ -59,12 +93,14 @@ def rope_apply(x, grid_sizes, freqs):
|
|
59 |
dim=-1).reshape(seq_len, 1, -1)
|
60 |
|
61 |
# apply rotary embedding
|
62 |
-
x_i
|
|
|
|
|
63 |
x_i = torch.cat([x_i, x[i, seq_len:]])
|
64 |
|
65 |
# append to collection
|
66 |
output.append(x_i)
|
67 |
-
return torch.stack(output)
|
68 |
|
69 |
|
70 |
class WanRMSNorm(nn.Module):
|
@@ -80,11 +116,31 @@ class WanRMSNorm(nn.Module):
|
|
80 |
Args:
|
81 |
x(Tensor): Shape [B, L, C]
|
82 |
"""
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
def _norm(self, x):
|
86 |
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
class WanLayerNorm(nn.LayerNorm):
|
90 |
|
@@ -96,7 +152,13 @@ class WanLayerNorm(nn.LayerNorm):
|
|
96 |
Args:
|
97 |
x(Tensor): Shape [B, L, C]
|
98 |
"""
|
99 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
|
102 |
class WanSelfAttention(nn.Module):
|
@@ -124,7 +186,7 @@ class WanSelfAttention(nn.Module):
|
|
124 |
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
125 |
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
126 |
|
127 |
-
def forward(self,
|
128 |
r"""
|
129 |
Args:
|
130 |
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
@@ -132,24 +194,31 @@ class WanSelfAttention(nn.Module):
|
|
132 |
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
133 |
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
134 |
"""
|
|
|
|
|
|
|
135 |
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
136 |
|
137 |
# query, key, value function
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
151 |
window_size=self.window_size)
|
152 |
-
|
153 |
# output
|
154 |
x = x.flatten(2)
|
155 |
x = self.o(x)
|
@@ -158,22 +227,31 @@ class WanSelfAttention(nn.Module):
|
|
158 |
|
159 |
class WanT2VCrossAttention(WanSelfAttention):
|
160 |
|
161 |
-
def forward(self,
|
162 |
r"""
|
163 |
Args:
|
164 |
x(Tensor): Shape [B, L1, C]
|
165 |
context(Tensor): Shape [B, L2, C]
|
166 |
context_lens(Tensor): Shape [B]
|
167 |
"""
|
|
|
|
|
168 |
b, n, d = x.size(0), self.num_heads, self.head_dim
|
169 |
|
170 |
# compute query, key, value
|
171 |
-
q = self.
|
172 |
-
|
|
|
|
|
|
|
|
|
|
|
173 |
v = self.v(context).view(b, -1, n, d)
|
174 |
|
175 |
# compute attention
|
176 |
-
|
|
|
|
|
177 |
|
178 |
# output
|
179 |
x = x.flatten(2)
|
@@ -196,31 +274,54 @@ class WanI2VCrossAttention(WanSelfAttention):
|
|
196 |
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
197 |
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
198 |
|
199 |
-
def forward(self,
|
200 |
r"""
|
201 |
Args:
|
202 |
x(Tensor): Shape [B, L1, C]
|
203 |
context(Tensor): Shape [B, L2, C]
|
204 |
context_lens(Tensor): Shape [B]
|
205 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
context_img = context[:, :257]
|
207 |
context = context[:, 257:]
|
208 |
b, n, d = x.size(0), self.num_heads, self.head_dim
|
209 |
|
210 |
# compute query, key, value
|
211 |
-
q = self.
|
212 |
-
|
|
|
|
|
|
|
|
|
|
|
213 |
v = self.v(context).view(b, -1, n, d)
|
214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
v_img = self.v_img(context_img).view(b, -1, n, d)
|
216 |
-
|
|
|
|
|
217 |
# compute attention
|
218 |
-
|
219 |
|
220 |
# output
|
221 |
x = x.flatten(2)
|
222 |
img_x = img_x.flatten(2)
|
223 |
-
x
|
|
|
224 |
x = self.o(x)
|
225 |
return x
|
226 |
|
@@ -289,27 +390,46 @@ class WanAttentionBlock(nn.Module):
|
|
289 |
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
290 |
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
291 |
"""
|
292 |
-
|
293 |
-
|
294 |
-
e = (self.modulation + e).chunk(6, dim=1)
|
295 |
-
assert e[0].dtype == torch.float32
|
296 |
-
|
297 |
# self-attention
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
313 |
return x
|
314 |
|
315 |
|
@@ -336,10 +456,13 @@ class Head(nn.Module):
|
|
336 |
x(Tensor): Shape [B, L1, C]
|
337 |
e(Tensor): Shape [B, C]
|
338 |
"""
|
339 |
-
assert e.dtype == torch.float32
|
340 |
-
|
341 |
-
|
342 |
-
|
|
|
|
|
|
|
343 |
return x
|
344 |
|
345 |
|
@@ -384,7 +507,8 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
384 |
window_size=(-1, -1),
|
385 |
qk_norm=True,
|
386 |
cross_attn_norm=True,
|
387 |
-
eps=1e-6
|
|
|
388 |
r"""
|
389 |
Initialize the diffusion model backbone.
|
390 |
|
@@ -466,7 +590,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
466 |
# buffers (don't use register_buffer otherwise dtype will be changed in to())
|
467 |
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
|
468 |
d = dim // num_heads
|
469 |
-
self.freqs = torch.cat([
|
470 |
rope_params(1024, d - 4 * (d // 6)),
|
471 |
rope_params(1024, 2 * (d // 6)),
|
472 |
rope_params(1024, 2 * (d // 6))
|
@@ -487,6 +611,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
487 |
seq_len,
|
488 |
clip_fea=None,
|
489 |
y=None,
|
|
|
490 |
):
|
491 |
r"""
|
492 |
Forward pass through the diffusion model
|
@@ -521,8 +646,11 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
521 |
|
522 |
# embeddings
|
523 |
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
524 |
-
grid_sizes = torch.stack(
|
525 |
-
|
|
|
|
|
|
|
526 |
x = [u.flatten(2).transpose(1, 2) for u in x]
|
527 |
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
528 |
assert seq_lens.max() <= seq_len
|
@@ -532,11 +660,10 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
532 |
])
|
533 |
|
534 |
# time embeddings
|
535 |
-
|
536 |
-
|
537 |
-
|
538 |
-
|
539 |
-
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
540 |
|
541 |
# context
|
542 |
context_lens = None
|
@@ -561,6 +688,9 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
561 |
context_lens=context_lens)
|
562 |
|
563 |
for block in self.blocks:
|
|
|
|
|
|
|
564 |
x = block(x, **kwargs)
|
565 |
|
566 |
# head
|
@@ -588,7 +718,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
588 |
|
589 |
c = self.out_dim
|
590 |
out = []
|
591 |
-
for u, v in zip(x, grid_sizes
|
592 |
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
|
593 |
u = torch.einsum('fhwpqrc->cfphqwr', u)
|
594 |
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
|
|
|
7 |
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
8 |
from diffusers.models.modeling_utils import ModelMixin
|
9 |
|
10 |
+
from .attention import pay_attention
|
11 |
|
12 |
__all__ = ['WanModel']
|
13 |
|
|
|
16 |
# preprocess
|
17 |
assert dim % 2 == 0
|
18 |
half = dim // 2
|
19 |
+
position = position.type(torch.float32)
|
20 |
|
21 |
# calculation
|
22 |
sinusoid = torch.outer(
|
|
|
25 |
return x
|
26 |
|
27 |
|
28 |
+
# @amp.autocast(enabled=False)
|
29 |
def rope_params(max_seq_len, dim, theta=10000):
|
30 |
assert dim % 2 == 0
|
31 |
freqs = torch.outer(
|
32 |
torch.arange(max_seq_len),
|
33 |
1.0 / torch.pow(theta,
|
34 |
+
torch.arange(0, dim, 2).to(torch.float32).div(dim)))
|
35 |
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
36 |
return freqs
|
37 |
|
38 |
|
39 |
+
def rope_apply_(x, grid_sizes, freqs):
|
40 |
+
assert x.shape[0]==1
|
41 |
+
|
42 |
+
n, c = x.size(2), x.size(3) // 2
|
43 |
+
|
44 |
+
# split freqs
|
45 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
46 |
+
|
47 |
+
f, h, w = grid_sizes[0]
|
48 |
+
seq_len = f * h * w
|
49 |
+
x_i = x[0, :seq_len, :, :]
|
50 |
+
|
51 |
+
x_i = x_i.to(torch.float32)
|
52 |
+
x_i = x_i.reshape(seq_len, n, -1, 2)
|
53 |
+
x_i = torch.view_as_complex(x_i)
|
54 |
+
freqs_i = torch.cat([
|
55 |
+
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
56 |
+
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
57 |
+
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
58 |
+
], dim=-1)
|
59 |
+
freqs_i= freqs_i.reshape(seq_len, 1, -1)
|
60 |
+
|
61 |
+
# apply rotary embedding
|
62 |
+
x_i *= freqs_i
|
63 |
+
x_i = torch.view_as_real(x_i).flatten(2)
|
64 |
+
x[0, :seq_len, :, :] = x_i.to(torch.bfloat16)
|
65 |
+
# x_i = torch.cat([x_i, x[0, seq_len:]])
|
66 |
+
return x
|
67 |
+
|
68 |
+
# @amp.autocast(enabled=False)
|
69 |
def rope_apply(x, grid_sizes, freqs):
|
70 |
n, c = x.size(2), x.size(3) // 2
|
71 |
|
|
|
74 |
|
75 |
# loop over samples
|
76 |
output = []
|
77 |
+
for i, (f, h, w) in enumerate(grid_sizes):
|
78 |
seq_len = f * h * w
|
79 |
|
80 |
# precompute multipliers
|
81 |
+
# x_i = x[i, :seq_len]
|
82 |
+
x_i = x[i]
|
83 |
+
x_i = x_i[:seq_len, :, :]
|
84 |
+
|
85 |
+
x_i = x_i.to(torch.float32)
|
86 |
+
x_i = x_i.reshape(seq_len, n, -1, 2)
|
87 |
+
x_i = torch.view_as_complex(x_i)
|
88 |
freqs_i = torch.cat([
|
89 |
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
90 |
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
|
|
93 |
dim=-1).reshape(seq_len, 1, -1)
|
94 |
|
95 |
# apply rotary embedding
|
96 |
+
x_i *= freqs_i
|
97 |
+
x_i = torch.view_as_real(x_i).flatten(2)
|
98 |
+
x_i = x_i.to(torch.bfloat16)
|
99 |
x_i = torch.cat([x_i, x[i, seq_len:]])
|
100 |
|
101 |
# append to collection
|
102 |
output.append(x_i)
|
103 |
+
return torch.stack(output) #.float()
|
104 |
|
105 |
|
106 |
class WanRMSNorm(nn.Module):
|
|
|
116 |
Args:
|
117 |
x(Tensor): Shape [B, L, C]
|
118 |
"""
|
119 |
+
y = x.float()
|
120 |
+
y.pow_(2)
|
121 |
+
y = y.mean(dim=-1, keepdim=True)
|
122 |
+
y += self.eps
|
123 |
+
y.rsqrt_()
|
124 |
+
x *= y
|
125 |
+
x *= self.weight
|
126 |
+
return x
|
127 |
+
# return self._norm(x).type_as(x) * self.weight
|
128 |
|
129 |
def _norm(self, x):
|
130 |
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
131 |
|
132 |
+
def my_LayerNorm(norm, x):
|
133 |
+
y = x.float()
|
134 |
+
y_m = y.mean(dim=-1, keepdim=True)
|
135 |
+
y -= y_m
|
136 |
+
del y_m
|
137 |
+
y.pow_(2)
|
138 |
+
y = y.mean(dim=-1, keepdim=True)
|
139 |
+
y += norm.eps
|
140 |
+
y.rsqrt_()
|
141 |
+
x = x * y
|
142 |
+
return x
|
143 |
+
|
144 |
|
145 |
class WanLayerNorm(nn.LayerNorm):
|
146 |
|
|
|
152 |
Args:
|
153 |
x(Tensor): Shape [B, L, C]
|
154 |
"""
|
155 |
+
# return F.layer_norm(
|
156 |
+
# input, self.normalized_shape, self.weight, self.bias, self.eps
|
157 |
+
# )
|
158 |
+
y = super().forward(x)
|
159 |
+
x = y.type_as(x)
|
160 |
+
return x
|
161 |
+
# return super().forward(x).type_as(x)
|
162 |
|
163 |
|
164 |
class WanSelfAttention(nn.Module):
|
|
|
186 |
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
187 |
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
188 |
|
189 |
+
def forward(self, xlist, seq_lens, grid_sizes, freqs):
|
190 |
r"""
|
191 |
Args:
|
192 |
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
|
|
194 |
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
195 |
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
196 |
"""
|
197 |
+
x = xlist[0]
|
198 |
+
xlist.clear()
|
199 |
+
|
200 |
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
201 |
|
202 |
# query, key, value function
|
203 |
+
q = self.q(x)
|
204 |
+
self.norm_q(q)
|
205 |
+
q = q.view(b, s, n, d) # !!!
|
206 |
+
k = self.k(x)
|
207 |
+
self.norm_k(k)
|
208 |
+
k = k.view(b, s, n, d)
|
209 |
+
v = self.v(x).view(b, s, n, d)
|
210 |
+
del x
|
211 |
+
rope_apply_(q, grid_sizes, freqs)
|
212 |
+
rope_apply_(k, grid_sizes, freqs)
|
213 |
+
qkv_list = [q,k,v]
|
214 |
+
del q,k,v
|
215 |
+
x = pay_attention(
|
216 |
+
qkv_list,
|
217 |
+
# q=q,
|
218 |
+
# k=k,
|
219 |
+
# v=v,
|
220 |
+
# k_lens=seq_lens,
|
221 |
window_size=self.window_size)
|
|
|
222 |
# output
|
223 |
x = x.flatten(2)
|
224 |
x = self.o(x)
|
|
|
227 |
|
228 |
class WanT2VCrossAttention(WanSelfAttention):
|
229 |
|
230 |
+
def forward(self, xlist, context, context_lens):
|
231 |
r"""
|
232 |
Args:
|
233 |
x(Tensor): Shape [B, L1, C]
|
234 |
context(Tensor): Shape [B, L2, C]
|
235 |
context_lens(Tensor): Shape [B]
|
236 |
"""
|
237 |
+
x = xlist[0]
|
238 |
+
xlist.clear()
|
239 |
b, n, d = x.size(0), self.num_heads, self.head_dim
|
240 |
|
241 |
# compute query, key, value
|
242 |
+
q = self.q(x)
|
243 |
+
del x
|
244 |
+
self.norm_q(q)
|
245 |
+
q= q.view(b, -1, n, d)
|
246 |
+
k = self.k(context)
|
247 |
+
self.norm_k(k)
|
248 |
+
k = k.view(b, -1, n, d)
|
249 |
v = self.v(context).view(b, -1, n, d)
|
250 |
|
251 |
# compute attention
|
252 |
+
qvl_list=[q, k, v]
|
253 |
+
del q, k, v
|
254 |
+
x = pay_attention(qvl_list, k_lens=context_lens)
|
255 |
|
256 |
# output
|
257 |
x = x.flatten(2)
|
|
|
274 |
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
275 |
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
276 |
|
277 |
+
def forward(self, xlist, context, context_lens):
|
278 |
r"""
|
279 |
Args:
|
280 |
x(Tensor): Shape [B, L1, C]
|
281 |
context(Tensor): Shape [B, L2, C]
|
282 |
context_lens(Tensor): Shape [B]
|
283 |
"""
|
284 |
+
|
285 |
+
##### Enjoy this spagheti VRAM optimizations done by DeepBeepMeep !
|
286 |
+
# I am sure you are a nice person and as you copy this code, you will give me officially proper credits:
|
287 |
+
# Please link to https://github.com/deepbeepmeep/Wan2GP and @deepbeepmeep on twitter
|
288 |
+
|
289 |
+
x = xlist[0]
|
290 |
+
xlist.clear()
|
291 |
+
|
292 |
context_img = context[:, :257]
|
293 |
context = context[:, 257:]
|
294 |
b, n, d = x.size(0), self.num_heads, self.head_dim
|
295 |
|
296 |
# compute query, key, value
|
297 |
+
q = self.q(x)
|
298 |
+
del x
|
299 |
+
self.norm_q(q)
|
300 |
+
q= q.view(b, -1, n, d)
|
301 |
+
k = self.k(context)
|
302 |
+
self.norm_k(k)
|
303 |
+
k = k.view(b, -1, n, d)
|
304 |
v = self.v(context).view(b, -1, n, d)
|
305 |
+
|
306 |
+
qkv_list = [q, k, v]
|
307 |
+
del k,v
|
308 |
+
x = pay_attention(qkv_list, k_lens=context_lens)
|
309 |
+
|
310 |
+
k_img = self.k_img(context_img)
|
311 |
+
self.norm_k_img(k_img)
|
312 |
+
k_img = k_img.view(b, -1, n, d)
|
313 |
v_img = self.v_img(context_img).view(b, -1, n, d)
|
314 |
+
qkv_list = [q, k_img, v_img]
|
315 |
+
del q, k_img, v_img
|
316 |
+
img_x = pay_attention(qkv_list, k_lens=None)
|
317 |
# compute attention
|
318 |
+
|
319 |
|
320 |
# output
|
321 |
x = x.flatten(2)
|
322 |
img_x = img_x.flatten(2)
|
323 |
+
x += img_x
|
324 |
+
del img_x
|
325 |
x = self.o(x)
|
326 |
return x
|
327 |
|
|
|
390 |
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
391 |
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
392 |
"""
|
393 |
+
e = (self.modulation + e).chunk(6, dim=1)
|
394 |
+
|
|
|
|
|
|
|
395 |
# self-attention
|
396 |
+
x_mod = self.norm1(x)
|
397 |
+
x_mod *= 1 + e[1]
|
398 |
+
x_mod += e[0]
|
399 |
+
xlist = [x_mod]
|
400 |
+
del x_mod
|
401 |
+
y = self.self_attn( xlist, seq_lens, grid_sizes,freqs)
|
402 |
+
x.addcmul_(y, e[2])
|
403 |
+
del y
|
404 |
+
y = self.norm3(x)
|
405 |
+
ylist= [y]
|
406 |
+
del y
|
407 |
+
x += self.cross_attn(ylist, context, context_lens)
|
408 |
+
y = self.norm2(x)
|
409 |
+
|
410 |
+
y *= 1 + e[4]
|
411 |
+
y += e[3]
|
412 |
+
|
413 |
+
|
414 |
+
ffn = self.ffn[0]
|
415 |
+
gelu = self.ffn[1]
|
416 |
+
ffn2= self.ffn[2]
|
417 |
+
|
418 |
+
y_shape = y.shape
|
419 |
+
y = y.view(-1, y_shape[-1])
|
420 |
+
chunk_size = int(y_shape[1]/2.7)
|
421 |
+
chunks =torch.split(y, chunk_size)
|
422 |
+
for y_chunk in chunks:
|
423 |
+
mlp_chunk = ffn(y_chunk)
|
424 |
+
mlp_chunk = gelu(mlp_chunk)
|
425 |
+
y_chunk[...] = ffn2(mlp_chunk)
|
426 |
+
del mlp_chunk
|
427 |
+
y = y.view(y_shape)
|
428 |
+
|
429 |
+
x.addcmul_(y, e[5])
|
430 |
+
|
431 |
+
|
432 |
+
|
433 |
return x
|
434 |
|
435 |
|
|
|
456 |
x(Tensor): Shape [B, L1, C]
|
457 |
e(Tensor): Shape [B, C]
|
458 |
"""
|
459 |
+
# assert e.dtype == torch.float32
|
460 |
+
|
461 |
+
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
|
462 |
+
x = self.norm(x).to(torch.bfloat16)
|
463 |
+
x *= (1 + e[1])
|
464 |
+
x += e[0]
|
465 |
+
x = self.head(x)
|
466 |
return x
|
467 |
|
468 |
|
|
|
507 |
window_size=(-1, -1),
|
508 |
qk_norm=True,
|
509 |
cross_attn_norm=True,
|
510 |
+
eps=1e-6,
|
511 |
+
):
|
512 |
r"""
|
513 |
Initialize the diffusion model backbone.
|
514 |
|
|
|
590 |
# buffers (don't use register_buffer otherwise dtype will be changed in to())
|
591 |
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
|
592 |
d = dim // num_heads
|
593 |
+
self.freqs = torch.cat([
|
594 |
rope_params(1024, d - 4 * (d // 6)),
|
595 |
rope_params(1024, 2 * (d // 6)),
|
596 |
rope_params(1024, 2 * (d // 6))
|
|
|
611 |
seq_len,
|
612 |
clip_fea=None,
|
613 |
y=None,
|
614 |
+
pipeline = None,
|
615 |
):
|
616 |
r"""
|
617 |
Forward pass through the diffusion model
|
|
|
646 |
|
647 |
# embeddings
|
648 |
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
649 |
+
# grid_sizes = torch.stack(
|
650 |
+
# [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
651 |
+
|
652 |
+
grid_sizes = [ list(u.shape[2:]) for u in x]
|
653 |
+
|
654 |
x = [u.flatten(2).transpose(1, 2) for u in x]
|
655 |
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
656 |
assert seq_lens.max() <= seq_len
|
|
|
660 |
])
|
661 |
|
662 |
# time embeddings
|
663 |
+
e = self.time_embedding(
|
664 |
+
sinusoidal_embedding_1d(self.freq_dim, t))
|
665 |
+
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(torch.bfloat16)
|
666 |
+
# assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
|
|
667 |
|
668 |
# context
|
669 |
context_lens = None
|
|
|
688 |
context_lens=context_lens)
|
689 |
|
690 |
for block in self.blocks:
|
691 |
+
if pipeline._interrupt:
|
692 |
+
return [None]
|
693 |
+
|
694 |
x = block(x, **kwargs)
|
695 |
|
696 |
# head
|
|
|
718 |
|
719 |
c = self.out_dim
|
720 |
out = []
|
721 |
+
for u, v in zip(x, grid_sizes):
|
722 |
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
|
723 |
u = torch.einsum('fhwpqrc->cfphqwr', u)
|
724 |
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
|
wan/modules/t5.py
CHANGED
@@ -442,7 +442,7 @@ def _t5(name,
|
|
442 |
model = model_cls(**kwargs)
|
443 |
|
444 |
# set device
|
445 |
-
model = model.to(dtype=dtype, device=device)
|
446 |
|
447 |
# init tokenizer
|
448 |
if return_tokenizer:
|
@@ -486,20 +486,25 @@ class T5EncoderModel:
|
|
486 |
self.checkpoint_path = checkpoint_path
|
487 |
self.tokenizer_path = tokenizer_path
|
488 |
|
|
|
489 |
# init model
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
|
|
495 |
logging.info(f'loading {checkpoint_path}')
|
496 |
-
|
|
|
|
|
497 |
self.model = model
|
498 |
if shard_fn is not None:
|
499 |
self.model = shard_fn(self.model, sync_module_states=False)
|
500 |
else:
|
501 |
self.model.to(self.device)
|
502 |
# init tokenizer
|
|
|
503 |
self.tokenizer = HuggingfaceTokenizer(
|
504 |
name=tokenizer_path, seq_len=text_len, clean='whitespace')
|
505 |
|
|
|
442 |
model = model_cls(**kwargs)
|
443 |
|
444 |
# set device
|
445 |
+
# model = model.to(dtype=dtype, device=device)
|
446 |
|
447 |
# init tokenizer
|
448 |
if return_tokenizer:
|
|
|
486 |
self.checkpoint_path = checkpoint_path
|
487 |
self.tokenizer_path = tokenizer_path
|
488 |
|
489 |
+
from accelerate import init_empty_weights
|
490 |
# init model
|
491 |
+
with init_empty_weights():
|
492 |
+
model = umt5_xxl(
|
493 |
+
encoder_only=True,
|
494 |
+
return_tokenizer=False,
|
495 |
+
dtype=dtype,
|
496 |
+
device=device).eval().requires_grad_(False)
|
497 |
logging.info(f'loading {checkpoint_path}')
|
498 |
+
from mmgp import offload
|
499 |
+
offload.load_model_data(model,checkpoint_path )
|
500 |
+
|
501 |
self.model = model
|
502 |
if shard_fn is not None:
|
503 |
self.model = shard_fn(self.model, sync_module_states=False)
|
504 |
else:
|
505 |
self.model.to(self.device)
|
506 |
# init tokenizer
|
507 |
+
tokenizer_path= "google/umt5-xxl"
|
508 |
self.tokenizer = HuggingfaceTokenizer(
|
509 |
name=tokenizer_path, seq_len=text_len, clean='whitespace')
|
510 |
|
wan/modules/vae.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
2 |
import logging
|
3 |
-
|
4 |
import torch
|
5 |
import torch.cuda.amp as amp
|
6 |
import torch.nn as nn
|
@@ -31,9 +31,16 @@ class CausalConv3d(nn.Conv3d):
|
|
31 |
cache_x = cache_x.to(x.device)
|
32 |
x = torch.cat([cache_x, x], dim=2)
|
33 |
padding[4] -= cache_x.shape[2]
|
|
|
34 |
x = F.pad(x, padding)
|
|
|
35 |
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
|
39 |
class RMS_norm(nn.Module):
|
@@ -49,10 +56,11 @@ class RMS_norm(nn.Module):
|
|
49 |
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
|
50 |
|
51 |
def forward(self, x):
|
52 |
-
|
53 |
x, dim=(1 if self.channel_first else
|
54 |
-1)) * self.scale * self.gamma + self.bias
|
55 |
-
|
|
|
56 |
|
57 |
class Upsample(nn.Upsample):
|
58 |
|
@@ -107,11 +115,12 @@ class Resample(nn.Module):
|
|
107 |
feat_cache[idx] = 'Rep'
|
108 |
feat_idx[0] += 1
|
109 |
else:
|
110 |
-
|
111 |
-
cache_x = x[:, :, -CACHE_T:, :, :]
|
112 |
if cache_x.shape[2] < 2 and feat_cache[
|
113 |
idx] is not None and feat_cache[idx] != 'Rep':
|
114 |
# cache last frame of last two chunk
|
|
|
115 |
cache_x = torch.cat([
|
116 |
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
117 |
cache_x.device), cache_x
|
@@ -119,11 +128,14 @@ class Resample(nn.Module):
|
|
119 |
dim=2)
|
120 |
if cache_x.shape[2] < 2 and feat_cache[
|
121 |
idx] is not None and feat_cache[idx] == 'Rep':
|
|
|
122 |
cache_x = torch.cat([
|
123 |
torch.zeros_like(cache_x).to(cache_x.device),
|
124 |
cache_x
|
125 |
],
|
126 |
dim=2)
|
|
|
|
|
127 |
if feat_cache[idx] == 'Rep':
|
128 |
x = self.time_conv(x)
|
129 |
else:
|
@@ -144,7 +156,7 @@ class Resample(nn.Module):
|
|
144 |
if feat_cache is not None:
|
145 |
idx = feat_idx[0]
|
146 |
if feat_cache[idx] is None:
|
147 |
-
feat_cache[idx] = x.clone()
|
148 |
feat_idx[0] += 1
|
149 |
else:
|
150 |
|
@@ -155,7 +167,7 @@ class Resample(nn.Module):
|
|
155 |
|
156 |
x = self.time_conv(
|
157 |
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
158 |
-
feat_cache[idx] = cache_x
|
159 |
feat_idx[0] += 1
|
160 |
return x
|
161 |
|
@@ -212,11 +224,11 @@ class ResidualBlock(nn.Module):
|
|
212 |
cache_x.device), cache_x
|
213 |
],
|
214 |
dim=2)
|
215 |
-
x = layer(x, feat_cache[idx])
|
216 |
-
feat_cache[idx] = cache_x
|
217 |
feat_idx[0] += 1
|
218 |
else:
|
219 |
-
x = layer(x)
|
220 |
return x + h
|
221 |
|
222 |
|
@@ -326,12 +338,16 @@ class Encoder3d(nn.Module):
|
|
326 |
cache_x.device), cache_x
|
327 |
],
|
328 |
dim=2)
|
329 |
-
x = self.conv1(x, feat_cache[idx])
|
330 |
feat_cache[idx] = cache_x
|
|
|
331 |
feat_idx[0] += 1
|
332 |
else:
|
333 |
x = self.conv1(x)
|
334 |
|
|
|
|
|
|
|
335 |
## downsamples
|
336 |
for layer in self.downsamples:
|
337 |
if feat_cache is not None:
|
@@ -339,6 +355,8 @@ class Encoder3d(nn.Module):
|
|
339 |
else:
|
340 |
x = layer(x)
|
341 |
|
|
|
|
|
342 |
## middle
|
343 |
for layer in self.middle:
|
344 |
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
@@ -346,6 +364,8 @@ class Encoder3d(nn.Module):
|
|
346 |
else:
|
347 |
x = layer(x)
|
348 |
|
|
|
|
|
349 |
## head
|
350 |
for layer in self.head:
|
351 |
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
@@ -360,9 +380,13 @@ class Encoder3d(nn.Module):
|
|
360 |
dim=2)
|
361 |
x = layer(x, feat_cache[idx])
|
362 |
feat_cache[idx] = cache_x
|
|
|
363 |
feat_idx[0] += 1
|
364 |
else:
|
365 |
x = layer(x)
|
|
|
|
|
|
|
366 |
return x
|
367 |
|
368 |
|
@@ -433,10 +457,12 @@ class Decoder3d(nn.Module):
|
|
433 |
],
|
434 |
dim=2)
|
435 |
x = self.conv1(x, feat_cache[idx])
|
436 |
-
feat_cache[idx] = cache_x
|
|
|
437 |
feat_idx[0] += 1
|
438 |
else:
|
439 |
x = self.conv1(x)
|
|
|
440 |
|
441 |
## middle
|
442 |
for layer in self.middle:
|
@@ -456,7 +482,7 @@ class Decoder3d(nn.Module):
|
|
456 |
for layer in self.head:
|
457 |
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
458 |
idx = feat_idx[0]
|
459 |
-
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
460 |
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
461 |
# cache last frame of last two chunk
|
462 |
cache_x = torch.cat([
|
@@ -465,7 +491,8 @@ class Decoder3d(nn.Module):
|
|
465 |
],
|
466 |
dim=2)
|
467 |
x = layer(x, feat_cache[idx])
|
468 |
-
feat_cache[idx] = cache_x
|
|
|
469 |
feat_idx[0] += 1
|
470 |
else:
|
471 |
x = layer(x)
|
@@ -532,6 +559,8 @@ class WanVAE_(nn.Module):
|
|
532 |
feat_cache=self._enc_feat_map,
|
533 |
feat_idx=self._enc_conv_idx)
|
534 |
out = torch.cat([out, out_], 2)
|
|
|
|
|
535 |
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
536 |
if isinstance(scale[0], torch.Tensor):
|
537 |
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
|
|
1 |
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
2 |
import logging
|
3 |
+
from mmgp import offload
|
4 |
import torch
|
5 |
import torch.cuda.amp as amp
|
6 |
import torch.nn as nn
|
|
|
31 |
cache_x = cache_x.to(x.device)
|
32 |
x = torch.cat([cache_x, x], dim=2)
|
33 |
padding[4] -= cache_x.shape[2]
|
34 |
+
cache_x = None
|
35 |
x = F.pad(x, padding)
|
36 |
+
x = super().forward(x)
|
37 |
|
38 |
+
mem_threshold = offload.shared_state.get("_vae_threshold",0)
|
39 |
+
vae_config = offload.shared_state.get("_vae",1)
|
40 |
+
|
41 |
+
if vae_config == 0 and torch.cuda.memory_reserved() > mem_threshold or vae_config == 2:
|
42 |
+
torch.cuda.empty_cache()
|
43 |
+
return x
|
44 |
|
45 |
|
46 |
class RMS_norm(nn.Module):
|
|
|
56 |
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
|
57 |
|
58 |
def forward(self, x):
|
59 |
+
x = F.normalize(
|
60 |
x, dim=(1 if self.channel_first else
|
61 |
-1)) * self.scale * self.gamma + self.bias
|
62 |
+
x = x.to(torch.bfloat16)
|
63 |
+
return x
|
64 |
|
65 |
class Upsample(nn.Upsample):
|
66 |
|
|
|
115 |
feat_cache[idx] = 'Rep'
|
116 |
feat_idx[0] += 1
|
117 |
else:
|
118 |
+
clone = True
|
119 |
+
cache_x = x[:, :, -CACHE_T:, :, :]#.clone()
|
120 |
if cache_x.shape[2] < 2 and feat_cache[
|
121 |
idx] is not None and feat_cache[idx] != 'Rep':
|
122 |
# cache last frame of last two chunk
|
123 |
+
clone = False
|
124 |
cache_x = torch.cat([
|
125 |
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
126 |
cache_x.device), cache_x
|
|
|
128 |
dim=2)
|
129 |
if cache_x.shape[2] < 2 and feat_cache[
|
130 |
idx] is not None and feat_cache[idx] == 'Rep':
|
131 |
+
clone = False
|
132 |
cache_x = torch.cat([
|
133 |
torch.zeros_like(cache_x).to(cache_x.device),
|
134 |
cache_x
|
135 |
],
|
136 |
dim=2)
|
137 |
+
if clone:
|
138 |
+
cache_x = cache_x.clone()
|
139 |
if feat_cache[idx] == 'Rep':
|
140 |
x = self.time_conv(x)
|
141 |
else:
|
|
|
156 |
if feat_cache is not None:
|
157 |
idx = feat_idx[0]
|
158 |
if feat_cache[idx] is None:
|
159 |
+
feat_cache[idx] = x #.to("cpu") #x.clone() yyyy
|
160 |
feat_idx[0] += 1
|
161 |
else:
|
162 |
|
|
|
167 |
|
168 |
x = self.time_conv(
|
169 |
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
170 |
+
feat_cache[idx] = cache_x#.to("cpu") #yyyyy
|
171 |
feat_idx[0] += 1
|
172 |
return x
|
173 |
|
|
|
224 |
cache_x.device), cache_x
|
225 |
],
|
226 |
dim=2)
|
227 |
+
x = layer(x, feat_cache[idx]).to(torch.bfloat16)
|
228 |
+
feat_cache[idx] = cache_x#.to("cpu")
|
229 |
feat_idx[0] += 1
|
230 |
else:
|
231 |
+
x = layer(x).to(torch.bfloat16)
|
232 |
return x + h
|
233 |
|
234 |
|
|
|
338 |
cache_x.device), cache_x
|
339 |
],
|
340 |
dim=2)
|
341 |
+
x = self.conv1(x, feat_cache[idx]).to(torch.bfloat16)
|
342 |
feat_cache[idx] = cache_x
|
343 |
+
del cache_x
|
344 |
feat_idx[0] += 1
|
345 |
else:
|
346 |
x = self.conv1(x)
|
347 |
|
348 |
+
|
349 |
+
# torch.cuda.empty_cache()
|
350 |
+
|
351 |
## downsamples
|
352 |
for layer in self.downsamples:
|
353 |
if feat_cache is not None:
|
|
|
355 |
else:
|
356 |
x = layer(x)
|
357 |
|
358 |
+
# torch.cuda.empty_cache()
|
359 |
+
|
360 |
## middle
|
361 |
for layer in self.middle:
|
362 |
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
|
|
364 |
else:
|
365 |
x = layer(x)
|
366 |
|
367 |
+
# torch.cuda.empty_cache()
|
368 |
+
|
369 |
## head
|
370 |
for layer in self.head:
|
371 |
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
|
|
380 |
dim=2)
|
381 |
x = layer(x, feat_cache[idx])
|
382 |
feat_cache[idx] = cache_x
|
383 |
+
del cache_x
|
384 |
feat_idx[0] += 1
|
385 |
else:
|
386 |
x = layer(x)
|
387 |
+
|
388 |
+
# torch.cuda.empty_cache()
|
389 |
+
|
390 |
return x
|
391 |
|
392 |
|
|
|
457 |
],
|
458 |
dim=2)
|
459 |
x = self.conv1(x, feat_cache[idx])
|
460 |
+
feat_cache[idx] = cache_x#.to("cpu")
|
461 |
+
del cache_x
|
462 |
feat_idx[0] += 1
|
463 |
else:
|
464 |
x = self.conv1(x)
|
465 |
+
cache_x = None
|
466 |
|
467 |
## middle
|
468 |
for layer in self.middle:
|
|
|
482 |
for layer in self.head:
|
483 |
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
484 |
idx = feat_idx[0]
|
485 |
+
cache_x = x[:, :, -CACHE_T:, :, :] .clone()
|
486 |
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
487 |
# cache last frame of last two chunk
|
488 |
cache_x = torch.cat([
|
|
|
491 |
],
|
492 |
dim=2)
|
493 |
x = layer(x, feat_cache[idx])
|
494 |
+
feat_cache[idx] = cache_x#.to("cpu")
|
495 |
+
del cache_x
|
496 |
feat_idx[0] += 1
|
497 |
else:
|
498 |
x = layer(x)
|
|
|
559 |
feat_cache=self._enc_feat_map,
|
560 |
feat_idx=self._enc_conv_idx)
|
561 |
out = torch.cat([out, out_], 2)
|
562 |
+
|
563 |
+
|
564 |
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
565 |
if isinstance(scale[0], torch.Tensor):
|
566 |
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
wan/text2video.py
CHANGED
@@ -35,6 +35,8 @@ class WanT2V:
|
|
35 |
dit_fsdp=False,
|
36 |
use_usp=False,
|
37 |
t5_cpu=False,
|
|
|
|
|
38 |
):
|
39 |
r"""
|
40 |
Initializes the Wan text-to-video generation model components.
|
@@ -70,18 +72,26 @@ class WanT2V:
|
|
70 |
text_len=config.text_len,
|
71 |
dtype=config.t5_dtype,
|
72 |
device=torch.device('cpu'),
|
73 |
-
checkpoint_path=
|
74 |
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
|
75 |
shard_fn=shard_fn if t5_fsdp else None)
|
76 |
|
77 |
self.vae_stride = config.vae_stride
|
78 |
self.patch_size = config.patch_size
|
|
|
|
|
79 |
self.vae = WanVAE(
|
80 |
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
|
81 |
device=self.device)
|
82 |
|
83 |
-
logging.info(f"Creating WanModel from {
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
self.model.eval().requires_grad_(False)
|
86 |
|
87 |
if use_usp:
|
@@ -98,12 +108,12 @@ class WanT2V:
|
|
98 |
else:
|
99 |
self.sp_size = 1
|
100 |
|
101 |
-
if dist.is_initialized():
|
102 |
-
|
103 |
-
if dit_fsdp:
|
104 |
-
|
105 |
-
else:
|
106 |
-
|
107 |
|
108 |
self.sample_neg_prompt = config.sample_neg_prompt
|
109 |
|
@@ -117,7 +127,9 @@ class WanT2V:
|
|
117 |
guide_scale=5.0,
|
118 |
n_prompt="",
|
119 |
seed=-1,
|
120 |
-
offload_model=True
|
|
|
|
|
121 |
r"""
|
122 |
Generates video frames from text prompt using diffusion process.
|
123 |
|
@@ -168,7 +180,7 @@ class WanT2V:
|
|
168 |
seed_g.manual_seed(seed)
|
169 |
|
170 |
if not self.t5_cpu:
|
171 |
-
self.text_encoder.model.to(self.device)
|
172 |
context = self.text_encoder([input_prompt], self.device)
|
173 |
context_null = self.text_encoder([n_prompt], self.device)
|
174 |
if offload_model:
|
@@ -223,23 +235,32 @@ class WanT2V:
|
|
223 |
# sample videos
|
224 |
latents = noise
|
225 |
|
226 |
-
arg_c = {'context': context, 'seq_len': seq_len}
|
227 |
-
arg_null = {'context': context_null, 'seq_len': seq_len}
|
228 |
|
229 |
-
|
|
|
|
|
|
|
230 |
latent_model_input = latents
|
231 |
timestep = [t]
|
232 |
|
233 |
timestep = torch.stack(timestep)
|
234 |
|
235 |
-
self.model.to(self.device)
|
236 |
noise_pred_cond = self.model(
|
237 |
latent_model_input, t=timestep, **arg_c)[0]
|
|
|
|
|
238 |
noise_pred_uncond = self.model(
|
239 |
latent_model_input, t=timestep, **arg_null)[0]
|
|
|
|
|
240 |
|
|
|
241 |
noise_pred = noise_pred_uncond + guide_scale * (
|
242 |
noise_pred_cond - noise_pred_uncond)
|
|
|
243 |
|
244 |
temp_x0 = sample_scheduler.step(
|
245 |
noise_pred.unsqueeze(0),
|
@@ -248,6 +269,10 @@ class WanT2V:
|
|
248 |
return_dict=False,
|
249 |
generator=seed_g)[0]
|
250 |
latents = [temp_x0.squeeze(0)]
|
|
|
|
|
|
|
|
|
251 |
|
252 |
x0 = latents
|
253 |
if offload_model:
|
@@ -256,6 +281,7 @@ class WanT2V:
|
|
256 |
if self.rank == 0:
|
257 |
videos = self.vae.decode(x0)
|
258 |
|
|
|
259 |
del noise, latents
|
260 |
del sample_scheduler
|
261 |
if offload_model:
|
|
|
35 |
dit_fsdp=False,
|
36 |
use_usp=False,
|
37 |
t5_cpu=False,
|
38 |
+
model_filename = None,
|
39 |
+
text_encoder_filename = None
|
40 |
):
|
41 |
r"""
|
42 |
Initializes the Wan text-to-video generation model components.
|
|
|
72 |
text_len=config.text_len,
|
73 |
dtype=config.t5_dtype,
|
74 |
device=torch.device('cpu'),
|
75 |
+
checkpoint_path=text_encoder_filename,
|
76 |
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
|
77 |
shard_fn=shard_fn if t5_fsdp else None)
|
78 |
|
79 |
self.vae_stride = config.vae_stride
|
80 |
self.patch_size = config.patch_size
|
81 |
+
|
82 |
+
|
83 |
self.vae = WanVAE(
|
84 |
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
|
85 |
device=self.device)
|
86 |
|
87 |
+
logging.info(f"Creating WanModel from {model_filename}")
|
88 |
+
from mmgp import offload
|
89 |
+
|
90 |
+
|
91 |
+
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel)
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
self.model.eval().requires_grad_(False)
|
96 |
|
97 |
if use_usp:
|
|
|
108 |
else:
|
109 |
self.sp_size = 1
|
110 |
|
111 |
+
# if dist.is_initialized():
|
112 |
+
# dist.barrier()
|
113 |
+
# if dit_fsdp:
|
114 |
+
# self.model = shard_fn(self.model)
|
115 |
+
# else:
|
116 |
+
# self.model.to(self.device)
|
117 |
|
118 |
self.sample_neg_prompt = config.sample_neg_prompt
|
119 |
|
|
|
127 |
guide_scale=5.0,
|
128 |
n_prompt="",
|
129 |
seed=-1,
|
130 |
+
offload_model=True,
|
131 |
+
callback = None
|
132 |
+
):
|
133 |
r"""
|
134 |
Generates video frames from text prompt using diffusion process.
|
135 |
|
|
|
180 |
seed_g.manual_seed(seed)
|
181 |
|
182 |
if not self.t5_cpu:
|
183 |
+
# self.text_encoder.model.to(self.device)
|
184 |
context = self.text_encoder([input_prompt], self.device)
|
185 |
context_null = self.text_encoder([n_prompt], self.device)
|
186 |
if offload_model:
|
|
|
235 |
# sample videos
|
236 |
latents = noise
|
237 |
|
238 |
+
arg_c = {'context': context, 'seq_len': seq_len, 'pipeline': self}
|
239 |
+
arg_null = {'context': context_null, 'seq_len': seq_len, 'pipeline': self}
|
240 |
|
241 |
+
if callback != None:
|
242 |
+
callback(-1, None)
|
243 |
+
self._interrupt = False
|
244 |
+
for i, t in enumerate(tqdm(timesteps)):
|
245 |
latent_model_input = latents
|
246 |
timestep = [t]
|
247 |
|
248 |
timestep = torch.stack(timestep)
|
249 |
|
250 |
+
# self.model.to(self.device)
|
251 |
noise_pred_cond = self.model(
|
252 |
latent_model_input, t=timestep, **arg_c)[0]
|
253 |
+
if self._interrupt:
|
254 |
+
return None
|
255 |
noise_pred_uncond = self.model(
|
256 |
latent_model_input, t=timestep, **arg_null)[0]
|
257 |
+
if self._interrupt:
|
258 |
+
return None
|
259 |
|
260 |
+
del latent_model_input
|
261 |
noise_pred = noise_pred_uncond + guide_scale * (
|
262 |
noise_pred_cond - noise_pred_uncond)
|
263 |
+
del noise_pred_uncond
|
264 |
|
265 |
temp_x0 = sample_scheduler.step(
|
266 |
noise_pred.unsqueeze(0),
|
|
|
269 |
return_dict=False,
|
270 |
generator=seed_g)[0]
|
271 |
latents = [temp_x0.squeeze(0)]
|
272 |
+
del temp_x0
|
273 |
+
|
274 |
+
if callback is not None:
|
275 |
+
callback(i, latents)
|
276 |
|
277 |
x0 = latents
|
278 |
if offload_model:
|
|
|
281 |
if self.rank == 0:
|
282 |
videos = self.vae.decode(x0)
|
283 |
|
284 |
+
|
285 |
del noise, latents
|
286 |
del sample_scheduler
|
287 |
if offload_model:
|