	Upload 54 files
This view is limited to 50 files because it contains too many changes. See raw diff.

Files changed:
- .gitattributes +11 -0
- LICENSE +201 -0
- README.md +13 -12
- app.py +435 -0
- assets/teaser.png +3 -0
- examples/amber.png +3 -0
- examples/armour.png +3 -0
- examples/art.wav +3 -0
- examples/chris.png +3 -0
- examples/dream.mp3 +3 -0
- examples/fictional.wav +3 -0
- examples/fight.wav +3 -0
- examples/jacket.png +3 -0
- examples/naomi.png +3 -0
- examples/science.wav +0 -0
- examples/vangogh.jpg +3 -0
- humo/common/__init__.py +0 -0
- humo/common/config.py +107 -0
- humo/common/distributed/__init__.py +41 -0
- humo/common/distributed/advanced.py +484 -0
- humo/common/distributed/basic.py +143 -0
- humo/common/logger.py +44 -0
- humo/configs/inference/generate.yaml +78 -0
- humo/configs/inference/generate_1_7B.yaml +76 -0
- humo/configs/models/Wan_1.3B.yaml +17 -0
- humo/configs/models/Wan_1.3B_I2V.yaml +18 -0
- humo/configs/models/Wan_14B.yaml +17 -0
- humo/configs/models/Wan_14B_I2V.yaml +18 -0
- humo/generate.py +984 -0
- humo/generate_1_7B.py +622 -0
- humo/models/audio/audio_proj.py +87 -0
- humo/models/distributed/__init__.py +0 -0
- humo/models/distributed/dit_ulysses_sequence_parallel.py +270 -0
- humo/models/distributed/fsdp.py +42 -0
- humo/models/text/encoder.py +173 -0
- humo/models/utils/fm_solvers.py +857 -0
- humo/models/utils/fm_solvers_unipc.py +800 -0
- humo/models/utils/utils.py +58 -0
- humo/models/wan_modules/__init__.py +16 -0
- humo/models/wan_modules/attention.py +256 -0
- humo/models/wan_modules/clip.py +542 -0
- humo/models/wan_modules/model.py +619 -0
- humo/models/wan_modules/model_humo.py +803 -0
- humo/models/wan_modules/t5.py +525 -0
- humo/models/wan_modules/tokenizers.py +82 -0
- humo/models/wan_modules/vae.py +666 -0
- humo/models/wan_modules/xlm_roberta.py +170 -0
- humo/utils/audio_processor_whisper.py +173 -0
- humo/utils/wav2vec.py +218 -0
- main.py +28 -0
    	
.gitattributes CHANGED

```diff
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/teaser.png filter=lfs diff=lfs merge=lfs -text
+examples/amber.png filter=lfs diff=lfs merge=lfs -text
+examples/armour.png filter=lfs diff=lfs merge=lfs -text
+examples/art.wav filter=lfs diff=lfs merge=lfs -text
+examples/chris.png filter=lfs diff=lfs merge=lfs -text
+examples/dream.mp3 filter=lfs diff=lfs merge=lfs -text
+examples/fictional.wav filter=lfs diff=lfs merge=lfs -text
+examples/fight.wav filter=lfs diff=lfs merge=lfs -text
+examples/jacket.png filter=lfs diff=lfs merge=lfs -text
+examples/naomi.png filter=lfs diff=lfs merge=lfs -text
+examples/vangogh.jpg filter=lfs diff=lfs merge=lfs -text
```
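These additions route the commit's new binary assets (images and audio) through Git LFS. For readers unfamiliar with the format: entries of this shape are what the `git lfs track` CLI appends to .gitattributes. A minimal sketch, kept in this repo's Python register (assumes git and git-lfs are installed; the patterns are illustrative):

```python
# Hypothetical sketch: produce .gitattributes entries like the ones above
# using the git-lfs CLI.
import subprocess

for pattern in ["assets/teaser.png", "examples/*.png", "examples/*.wav"]:
    subprocess.check_call(["git", "lfs", "track", pattern])
```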
    	
LICENSE ADDED

The new file (+201 lines) is the standard, unmodified Apache License 2.0 text: Sections 1 through 9 (Definitions, Grant of Copyright License, Grant of Patent License, Redistribution, Submission of Contributions, Trademarks, Disclaimer of Warranty, Limitation of Liability, Accepting Warranty or Additional Liability) plus the appendix, applied with the following notice:

```
Copyright 2025 Bytedance

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
```
    	
README.md CHANGED

The Space front matter is rewritten: the title changes from "HuMo Local" to "HuMo [Local]", sdk_version is pinned to 5.47.2, and a short_description is added (the remaining old values are truncated in this view). New contents:

```yaml
---
title: HuMo [Local]
emoji: 👩‍🦱
colorFrom: purple
colorTo: gray
sdk: gradio
sdk_version: 5.47.2
app_file: app.py
pinned: false
short_description: Reference based video generation
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
```
    	
app.py ADDED (+435 lines)
```python
import spaces
import gradio as gr
import sys
import os
import subprocess
import uuid
import shutil

from huggingface_hub import snapshot_download, list_repo_files, hf_hub_download
import importlib, site

# Re-discover all .pth/.egg-link files
for sitedir in site.getsitepackages():
    site.addsitedir(sitedir)

# Clear caches so importlib will pick up new modules
importlib.invalidate_caches()

def sh(cmd):
    """Run a shell command, raising if it fails."""
    subprocess.check_call(cmd, shell=True)
```
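The site re-scan above matters because app.py installs wheels at runtime (next) and then imports them in the same process. A minimal sketch of that install-then-import pattern; "somepackage" is a placeholder, not a real dependency of this Space:

```python
# Minimal sketch of the runtime install-then-import pattern used below.
# "somepackage" is a placeholder name.
import importlib, site, subprocess

subprocess.check_call("pip install somepackage", shell=True)
for sitedir in site.getsitepackages():
    site.addsitedir(sitedir)      # re-discover .pth/.egg-link files
importlib.invalidate_caches()     # let the import system see new modules
mod = importlib.import_module("somepackage")
```

app.py continues: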
            +
            flash_attention_installed = False
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            try:
         | 
| 27 | 
            +
                flash_attention_wheel = hf_hub_download(
         | 
| 28 | 
            +
                        repo_id="alexnasa/flash-attn-3",
         | 
| 29 | 
            +
                        repo_type="model",
         | 
| 30 | 
            +
                        filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
         | 
| 31 | 
            +
                    )
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                sh(f"pip install {flash_attention_wheel}")
         | 
| 34 | 
            +
                print("Attempting to download and install FlashAttention wheel...")
         | 
| 35 | 
            +
                # sh("pip install flash-attn")
         | 
| 36 | 
            +
                sh("pip install --no-build-isolation transformer_engine-2.5.0+f05f12c9-cp310-cp310-linux_x86_64.whl")
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                # tell Python to re-scan site-packages now that the egg-link exists
         | 
| 39 | 
            +
                import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                flash_attention_installed = True
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            except Exception as e:
         | 
| 44 | 
            +
                print(f"⚠️ Could not install FlashAttention: {e}")
         | 
| 45 | 
            +
                print("Continuing without FlashAttention...")
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            try:
         | 
| 48 | 
            +
                te_wheel = hf_hub_download(
         | 
| 49 | 
            +
                        repo_id="alexnasa/transformer_engine_wheels",
         | 
| 50 | 
            +
                        repo_type="model",
         | 
| 51 | 
            +
                        filename="transformer_engine-2.5.0+f05f12c9-cp310-cp310-linux_x86_64.whl",
         | 
| 52 | 
            +
                    )
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                sh(f"pip install {te_wheel}")
         | 
| 55 | 
            +
                print("Attempting to download and install Transformer Engine wheel...")
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                # tell Python to re-scan site-packages now that the egg-link exists
         | 
| 58 | 
            +
                import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
         | 
| 59 | 
            +
             | 
| 60 | 
            +
            except Exception as e:
         | 
| 61 | 
            +
                print(f"⚠️ Could not install Transformer Engine : {e}")
         | 
| 62 | 
            +
                print("Continuing without Transformer Engine ...")
         | 
| 63 | 
            +
             | 
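Since both installs are best-effort, a post-install import check can catch a wheel that installed but does not import. The module names below are assumptions about what these wheels expose (FlashAttention 3 typically ships `flash_attn_interface`), not something this commit confirms:

```python
# Optional sanity check (not part of the original file); module names are
# assumptions about what the wheels above provide.
for name in ("flash_attn_interface", "transformer_engine"):
    try:
        importlib.import_module(name)
        print(f"{name}: importable")
    except ImportError as err:
        print(f"{name}: unavailable ({err})")
```

app.py continues: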
```python
import torch
print(f"Torch version: {torch.__version__}")
print(f"FlashAttention available: {flash_attention_installed}")

import tempfile
from pathlib import Path
from torch._inductor.runtime.runtime_utils import cache_dir as _inductor_cache_dir
from huggingface_hub import HfApi

# Fetch the HuMo checkpoint, the Wan 2.1 T2V base model, and Whisper at startup.
snapshot_download(repo_id="bytedance-research/HuMo", local_dir="./weights/HuMo")
snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./weights/Wan2.1-T2V-1.3B")
snapshot_download(repo_id="openai/whisper-large-v3", local_dir="./weights/whisper-large-v3")

# Per-session output root for generated videos.
os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/preprocess_results"
```
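The three `snapshot_download` calls fetch full repositories into ./weights. If disk or startup time became a concern, `snapshot_download` also accepts `allow_patterns` to fetch a subset; the pattern below is illustrative only, not taken from this commit:

```python
# Hypothetical variant: download only part of a checkpoint repo.
snapshot_download(
    repo_id="bytedance-research/HuMo",
    local_dir="./weights/HuMo",
    allow_patterns=["HuMo-1.7B/*"],  # illustrative pattern, not verified
)
```

app.py continues: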
```python
# Make the repo-local "humo" package importable.
path_to_insert = "humo"
if path_to_insert not in sys.path:
    sys.path.insert(0, path_to_insert)

from common.config import load_config, create_object

config = load_config(
    "./humo/configs/inference/generate.yaml",
    [
        "dit.sp_size=1",
        "generation.frames=97",
        "generation.scale_t=5.5",
        "generation.scale_a=5.0",
        "generation.mode=TIA",
        "generation.height=480",
        "generation.width=832",
    ],
)
runner = create_object(config)
```
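`load_config` and `create_object` are HuMo-internal helpers (humo/common/config.py in this commit); the override list uses dot-list syntax in the style of OmegaConf. A rough, assumed equivalent for readers unfamiliar with the pattern, not the actual HuMo implementation:

```python
# Assumed equivalent of the dot-list overrides above (OmegaConf-style).
from omegaconf import OmegaConf

base = OmegaConf.load("./humo/configs/inference/generate.yaml")
overrides = OmegaConf.from_dotlist(["generation.height=480", "generation.width=832"])
merged = OmegaConf.merge(base, overrides)
print(merged.generation.height)  # 480
```

app.py continues: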
```python
os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", f"{os.getcwd()}/torchinductor_space")  # or another writable path

def restore_inductor_cache_from_hub(repo_id: str, filename: str = "torch_compile_cache.zip",
                                    path_in_repo: str = "inductor_cache", repo_type: str = "model",
                                    hf_token: str | None = None):
    """Download a zipped TorchInductor cache from the Hub and unpack it locally."""
    cache_root = Path(_inductor_cache_dir()).resolve()
    cache_root.mkdir(parents=True, exist_ok=True)
    zip_path = hf_hub_download(repo_id=repo_id, filename=f"{path_in_repo}/{filename}",
                               repo_type=repo_type, token=hf_token)
    shutil.unpack_archive(zip_path, extract_dir=str(cache_root))
    print(f"✓ Restored cache into {cache_root}")


# restore_inductor_cache_from_hub("alexnasa/humo-compiled")


def get_duration(prompt_text, steps, image_paths, audio_file_path, tea_cache_l1_thresh, max_duration, session_id):
    # Called by the @spaces.GPU decorator with the same arguments as run_pipeline.
    return calculate_required_time(steps, max_duration)

def calculate_required_time(steps, max_duration):
    warmup_s = 60

    # Rough seconds per diffusion step for each "Max Duration" setting.
    max_duration_duration_mapping = {
        1: 8,
        2: 8,
        3: 11,
        4: 20,
        5: 30,
    }
    each_step_s = max_duration_duration_mapping[max_duration]
    duration_s = (each_step_s * steps) + warmup_s

    print(f"estimated duration: {duration_s}")

    return int(duration_s)

def get_required_time_string(steps, max_duration):
    duration_s = calculate_required_time(steps, max_duration)
    duration_m = duration_s / 60

    return f"<center>⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)</center>"

def update_required_time(steps, max_duration):
    return get_required_time_string(steps, max_duration)


def generate_scene(prompt_text, steps, image_paths, audio_file_path, tea_cache_l1_thresh, max_duration=2, session_id=None):
    print(image_paths)
    prompt_text_check = (prompt_text or "").strip()
    if not prompt_text_check:
        raise gr.Error("Please enter a prompt.")

    if not audio_file_path and not image_paths:
        raise gr.Error("Please provide a reference image or a lipsync audio.")

    return run_pipeline(prompt_text, steps, image_paths, audio_file_path, tea_cache_l1_thresh, max_duration, session_id)
```
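With the UI defaults used further down (10 steps, max duration 2), the estimate works out to 8 s/step × 10 steps + 60 s warmup = 140 s:

```python
# Worked example of the estimate, using the UI defaults.
assert calculate_required_time(steps=10, max_duration=2) == 8 * 10 + 60  # 140 s
```

app.py continues: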
```python
def upload_inductor_cache_to_hub(
    repo_id: str,
    path_in_repo: str = "inductor_cache",
    repo_type: str = "model",   # or "dataset" if you prefer
    hf_token: str | None = None,
):
    """
    Zips the current TorchInductor cache and uploads it to the given repo path.
    Assumes the model was already run once with torch.compile() so the cache exists.
    """
    cache_dir = Path(_inductor_cache_dir()).resolve()
    if not cache_dir.exists():
        raise FileNotFoundError(f"TorchInductor cache not found at {cache_dir}. "
                                "Run a compiled model once to populate it.")

    # Create a zip archive of the entire cache directory
    with tempfile.TemporaryDirectory() as tmpdir:
        archive_base = Path(tmpdir) / "torch_compile_cache"
        archive_path = shutil.make_archive(str(archive_base), "zip", root_dir=str(cache_dir))
        archive_path = Path(archive_path)

        # Upload the zip to the Hub under path_in_repo
        api = HfApi(token=hf_token)
        api.create_repo(repo_id=repo_id, repo_type=repo_type, exist_ok=True)
        dest_path = f"{path_in_repo}/{archive_path.name}"
        api.upload_file(
            path_or_fileobj=str(archive_path),
            path_in_repo=dest_path,
            repo_id=repo_id,
            repo_type=repo_type,
        )
        # Upload a small metadata stamp for traceability (optional but handy)
        meta_txt = (
            f"pytorch={torch.__version__}\n"
            f"inductor_cache_dir={cache_dir}\n"
            f"cuda_available={torch.cuda.is_available()}\n"
            f"cuda_device={torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu'}\n"
        )
        api.upload_file(
            path_or_fileobj=meta_txt.encode(),
            path_in_repo=f"{path_in_repo}/INDUCTOR_CACHE_METADATA.txt",
            repo_id=repo_id,
            repo_type=repo_type,
        )

    print("✔ Uploaded TorchInductor cache to the Hub.")
```
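Together with `restore_inductor_cache_from_hub` above, this gives a save/restore cycle for torch.compile artifacts across Space restarts. Both calls are currently commented out in this file, so the pairing below is an assumption about intended use (the hypothetical SAVE_INDUCTOR_CACHE switch is mine):

```python
# Sketch of the intended save/restore cycle, assuming the commented-out
# calls in app.py; SAVE_INDUCTOR_CACHE is a hypothetical switch.
if os.environ.get("SAVE_INDUCTOR_CACHE"):
    # On a warm instance, after the first compiled run:
    upload_inductor_cache_to_hub("alexnasa/humo-compiled",
                                 hf_token=os.environ.get("HF_TOKEN"))
else:
    # On the next cold start, before the first compiled call:
    restore_inductor_cache_from_hub("alexnasa/humo-compiled")
```

app.py continues: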
```python
@spaces.GPU(duration=get_duration)
def run_pipeline(prompt_text, steps, image_paths, audio_file_path, tea_cache_l1_thresh=0.0, max_duration=2, session_id=None):

    if session_id is None:
        session_id = uuid.uuid4().hex

    inference_mode = "TIA"

    # Validate inputs
    prompt_text = (prompt_text or "").strip()
    if not prompt_text:
        raise gr.Error("Please enter a prompt.")

    if not audio_file_path and not image_paths:
        raise gr.Error("Please provide a reference image or a lipsync audio.")

    if not audio_file_path:
        inference_mode = "TI"   # text + image only
        audio_path = None
    else:
        audio_path = audio_file_path if isinstance(audio_file_path, str) else getattr(audio_file_path, "name", str(audio_file_path))

    if not image_paths:
        inference_mode = "TA"   # text + audio only
        img_paths = None
    else:
        img_paths = [image_data[0] for image_data in image_paths]

    # Prepare output
    output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
    os.makedirs(output_dir, exist_ok=True)

    # Random filename
    filename = f"gen_{uuid.uuid4().hex[:10]}"
    width, height = 832, 480

    # Frame counts corresponding to each "Max Duration" setting (seconds).
    duration_frame_mapping = {
        1: 25,
        2: 45,
        3: 70,
        4: 97,
        5: 129,
    }

    # Run inference
    runner.inference_loop(
        prompt_text,
        img_paths,
        audio_path,
        output_dir,
        filename,
        inference_mode,
        width,
        height,
        steps,
        frames=int(duration_frame_mapping[max_duration]),
        tea_cache_l1_thresh=tea_cache_l1_thresh,
    )

    # Return resulting video path
    video_path = os.path.join(output_dir, f"{filename}.mp4")
    if os.path.exists(video_path):
        # upload_inductor_cache_to_hub("alexnasa/humo-compiled")
        return video_path
    else:
        candidates = [os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.endswith(".mp4")]
        if candidates:
            return max(candidates, key=lambda p: os.path.getmtime(p))
        return None
```
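Note the decorator: ZeroGPU's `@spaces.GPU` accepts a callable for `duration`, invoked with the same arguments as the wrapped function, which is how `get_duration` sizes the GPU reservation per request instead of using one fixed budget. A minimal sketch of the same pattern (the job itself is hypothetical):

```python
# Minimal sketch of per-call GPU durations, same pattern as
# run_pipeline/get_duration above; tiny_job is a placeholder.
def tiny_job_duration(steps):
    return 60 + 8 * steps   # seconds of GPU time to request

@spaces.GPU(duration=tiny_job_duration)
def tiny_job(steps):
    return steps
```

app.py continues: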
```python
css = """
    #col-container {
        margin: 0 auto;
        width: 100%;
        max-width: 720px;
    }
    """

def cleanup(request: gr.Request):
    # Remove this session's generated files when the browser session ends.
    sid = request.session_hash
    if sid:
        d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
        shutil.rmtree(d1, ignore_errors=True)

def start_session(request: gr.Request):
    return request.session_hash
```
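`cleanup` still has to be attached to the session lifecycle; that wiring is presumably in the portion of app.py cut off below. In recent Gradio versions it would typically look like the following (a presumption, not visible in this diff):

```python
# Presumed wiring, once `demo` exists below: run cleanup when a user's
# browser session closes.
demo.unload(cleanup)
```

app.py continues: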
```python
with gr.Blocks(css=css) as demo:

    session_state = gr.State()
    demo.load(start_session, outputs=[session_state])

    with gr.Sidebar(width=400):

        gr.HTML(
            """
            <div style="text-align: center;">
                <p style="font-size:16px; display: inline; margin: 0;">
                    <strong>HuMo</strong> – Human-Centric Video Generation via Collaborative Multi-Modal Conditioning
                </p>
                <a href="https://github.com/Phantom-video/HuMo" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
                    [Github]
                </a>
            </div>
            """
        )

        gr.Markdown("**REFERENCE IMAGES**")

        img_input = gr.Gallery(
            show_label=False,
            label="",
            interactive=True,
            rows=1, columns=3, object_fit="contain", height="280",
            file_types=['image']
        )

        gr.Markdown("**LIPSYNC AUDIO**")

        audio_input = gr.Audio(
            sources=["upload"],
            show_label=False,
            type="filepath",
        )

        gr.Markdown("**SETTINGS**")

        default_steps = 10
        default_max_duration = 2

        max_duration = gr.Slider(minimum=2, maximum=5, value=default_max_duration, step=1, label="Max Duration")
        steps_input = gr.Slider(minimum=5, maximum=50, value=default_steps, step=5, label="Diffusion Steps")
        tea_cache_l1_thresh = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.01, label="Cache", visible=False)

    with gr.Column(elem_id="col-container"):

        gr.HTML(
            """
            <div style="text-align: center;">
                <strong>HF Space by:</strong>
                <a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
                    <img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="GitHub Repo">
                </a>
            </div>
            """
        )

        video_output = gr.Video(show_label=False)

        gr.Markdown("<center><h2>PROMPT</h2></center>")

        prompt_tb = gr.Textbox(
            show_label=False,
            lines=5,
            placeholder="Describe the scene and the person talking....",
        )

        gr.Markdown("")
        time_required = gr.Markdown(get_required_time_string(default_steps, default_max_duration))
        run_btn = gr.Button("🎬 Action", variant="primary")

        gr.Examples(
            examples=[

                [
                    "A handheld tracking shot follows a female warrior walking through a cave. Her determined eyes are locked straight ahead. She speaks with intensity.",
                    5,
```

[The diff view is truncated here; app.py lines 391-435 are not shown.]
         | 
| 391 | 
            +
                                ["./examples/naomi.png"], 
         | 
| 392 | 
            +
                                "./examples/dream.mp3",                  
         | 
| 393 | 
            +
                            ],
         | 
| 394 | 
            +
             | 
| 395 | 
            +
                            [
         | 
| 396 | 
            +
                                "A reddish-brown haired and bearded man sits pensively against swirling blue-and-white brushstrokes, dressed in a blue coat and dark waistcoat. The artistic backdrop and his thoughtful pose evoke a Post-Impressionist style in a studio-like setting.",
         | 
| 397 | 
            +
                                10,
         | 
| 398 | 
            +
                                ["./examples/vangogh.jpg"], 
         | 
| 399 | 
            +
                                "./examples/art.wav",               
         | 
| 400 | 
            +
                            ],
         | 
| 401 | 
            +
             | 
| 402 | 
            +
                            [
         | 
| 403 | 
            +
                                "A handheld tracking shot follows a female through a science lab. Her determined eyes are locked straight ahead. The clip is in black and white and patchy as she is explaining something to someone standing opposite her",
         | 
| 404 | 
            +
                                10,
         | 
| 405 | 
            +
                                ["./examples/naomi.png"], 
         | 
| 406 | 
            +
                                "./examples/science.wav",               
         | 
| 407 | 
            +
                            ],
         | 
| 408 | 
            +
             | 
| 409 | 
            +
                            [
         | 
| 410 | 
            +
                                "A woman with long, wavy dark hair looking at a person sitting opposite her whilst holding a book, wearing a leather jacket, long-sleeved jacket with a semi purple color one seen on a photo. Warm, window-like light bathes her figure, highlighting the outfit's elegant design and her graceful movements.",
         | 
| 411 | 
            +
                                50,
         | 
| 412 | 
            +
                                ["./examples/amber.png", "./examples/jacket.png"],
         | 
| 413 | 
            +
                                "./examples/fictional.mp3",                  
         | 
| 414 | 
            +
                            ],
         | 
| 415 | 
            +
             | 
| 416 | 
            +
                        ],
         | 
| 417 | 
            +
                        inputs=[prompt_tb, steps_input, img_input, audio_input],
         | 
| 418 | 
            +
                        outputs=[video_output],
         | 
| 419 | 
            +
                        fn=run_pipeline,
         | 
| 420 | 
            +
                        cache_examples=True,
         | 
| 421 | 
            +
                    )
         | 
| 422 | 
            +
                    max_duration.change(update_required_time, [steps_input, max_duration], time_required)
         | 
| 423 | 
            +
                    steps_input.change(update_required_time, [steps_input, max_duration], time_required)
         | 
| 424 | 
            +
             | 
| 425 | 
            +
                    run_btn.click(
         | 
| 426 | 
            +
                        fn=generate_scene,
         | 
| 427 | 
            +
                        inputs=[prompt_tb, steps_input, img_input, audio_input, tea_cache_l1_thresh, max_duration, session_state],
         | 
| 428 | 
            +
                        outputs=[video_output],
         | 
| 429 | 
            +
                    )
         | 
| 430 | 
            +
             | 
| 431 | 
            +
             | 
| 432 | 
            +
            if __name__ == "__main__":
         | 
| 433 | 
            +
                demo.unload(cleanup)
         | 
| 434 | 
            +
                demo.queue()
         | 
| 435 | 
            +
                demo.launch(ssr_mode=False)
         | 
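
The get_required_time_string and update_required_time helpers wired above are defined earlier in app.py; the UI only assumes they map the two slider values to a Markdown string. A hypothetical stand-in, for illustration only:

    def get_required_time_string(steps: int, max_duration: int) -> str:
        # hypothetical estimate; the real logic lives earlier in app.py
        minutes = steps * max_duration * 2 / 60
        return f"Estimated generation time: ~{minutes:.1f} min"

    def update_required_time(steps: int, max_duration: int) -> str:
        return get_required_time_string(steps, max_duration)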
    	
assets/teaser.png
ADDED
    Git LFS Details

examples/amber.png
ADDED
    Git LFS Details

examples/armour.png
ADDED
    Git LFS Details

examples/art.wav
ADDED
    version https://git-lfs.github.com/spec/v1
    oid sha256:72c75df8e93a107e262ea9b002a66e72d3c1cd2084bce1474a31d8afffd0b651
    size 114254

examples/chris.png
ADDED
    Git LFS Details

examples/dream.mp3
ADDED
    version https://git-lfs.github.com/spec/v1
    oid sha256:27248fd9e8f29bd60ccb1163b8df3c6f2630734f358aa3362ffe67e8148e0eb1
    size 108275

examples/fictional.wav
ADDED
    version https://git-lfs.github.com/spec/v1
    oid sha256:31b550e6433ea44a0642dee90c326664ff4f568fec184170001f834597b3ad23
    size 167084

examples/fight.wav
ADDED
    version https://git-lfs.github.com/spec/v1
    oid sha256:8dbee86c85e992ac6d17820a3730bf753fc9bf5bac6b8a470f84b7e98a64221a
    size 264782

examples/jacket.png
ADDED
    Git LFS Details

examples/naomi.png
ADDED
    Git LFS Details

examples/science.wav
ADDED
    Binary file (82.5 kB)

examples/vangogh.jpg
ADDED
    Git LFS Details

humo/common/__init__.py
ADDED
    File without changes
    	
humo/common/config.py
ADDED

# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Codes adapted from [SeedVR]
# https://github.com/ByteDance-Seed/SeedVR/blob/main/common/config.py

"""
Configuration utility functions.
"""

import importlib
from typing import Any, Callable, List, Union
from omegaconf import DictConfig, ListConfig, OmegaConf

OmegaConf.register_new_resolver("eval", eval)


def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]:
    """
    Load a configuration. Will resolve inheritance.
    """
    config = OmegaConf.load(path)
    if argv is not None:
        config_argv = OmegaConf.from_dotlist(argv)
        config = OmegaConf.merge(config, config_argv)
    config = resolve_recursive(config, resolve_inheritance)
    return config


def resolve_recursive(
    config: Any,
    resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]],
) -> Any:
    config = resolver(config)
    if isinstance(config, DictConfig):
        for k in config.keys():
            v = config.get(k)
            if isinstance(v, (DictConfig, ListConfig)):
                config[k] = resolve_recursive(v, resolver)
    if isinstance(config, ListConfig):
        for i in range(len(config)):
            v = config.get(i)
            if isinstance(v, (DictConfig, ListConfig)):
                config[i] = resolve_recursive(v, resolver)
    return config


def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any:
    """
    Recursively resolve inheritance if the config contains:
    __inherit__: path/to/parent.yaml
    """
    if isinstance(config, DictConfig):
        inherit = config.pop("__inherit__", None)
        if inherit:
            assert isinstance(inherit, str)
            inherit = load_config(inherit)
            if len(config.keys()) > 0:
                config = OmegaConf.merge(inherit, config)
            else:
                config = inherit
    return config


def import_item(path: str, name: str) -> Any:
    """
    Import a Python item. Example: import_item("path.to.file", "MyClass") -> MyClass
    """
    return getattr(importlib.import_module(path), name)


def create_object(config: DictConfig) -> Any:
    """
    Create an object from config.
    The config is expected to contain the following:
    __object__:
      path: path.to.module
      name: MyClass
      args: as_config | as_params (defaults to as_config)
    """
    item = import_item(
        path=config.__object__.path,
        name=config.__object__.name,
    )
    args = config.__object__.get("args", "as_config")
    if args == "as_config":
        return item(config)
    if args == "as_params":
        config = OmegaConf.to_object(config)
        config.pop("__object__")
        return item(**config)
    raise NotImplementedError(f"Unknown args type: {args}")


def create_dataset(path: str, *args, **kwargs) -> Any:
    """
    Create a dataset. Requires the file to contain a "create_dataset" function.
    """
    return import_item(path, "create_dataset")(*args, **kwargs)
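
A minimal usage sketch for the "__object__" convention documented above (illustrative only; torch.nn.Linear stands in for any target class, and create_object is the function defined in this file):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "__object__": {"path": "torch.nn", "name": "Linear", "args": "as_params"},
        "in_features": 8,
        "out_features": 2,
    })
    layer = create_object(cfg)  # equivalent to torch.nn.Linear(in_features=8, out_features=2)

With "args: as_config" the target class would instead receive the whole DictConfig as its single constructor argument.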
    	
humo/common/distributed/__init__.py
ADDED

# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Codes adapted from [SeedVR]
# https://github.com/ByteDance-Seed/SeedVR/tree/main/common/distributed

"""
Distributed package.
"""

from .basic import (
    barrier_if_distributed,
    convert_to_ddp,
    get_device,
    get_global_rank,
    get_local_rank,
    get_world_size,
    init_torch,
    meta_param_init_fn,
    meta_non_persistent_buffer_init_fn,
)

__all__ = [
    "barrier_if_distributed",
    "convert_to_ddp",
    "get_device",
    "get_global_rank",
    "get_local_rank",
    "get_world_size",
    "init_torch",
    "meta_param_init_fn",
    "meta_non_persistent_buffer_init_fn",
]
    	
humo/common/distributed/advanced.py
ADDED

# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Codes adapted from [SeedVR]
# https://github.com/ByteDance-Seed/SeedVR/tree/main/common/distributed

"""
Advanced distributed functions for sequence parallelism.
"""

import torch
from typing import Any, List, Optional, Tuple, Union
import torch.distributed as dist
from torch import Tensor

from .basic import get_global_rank, get_world_size


_DATA_PARALLEL_GROUP = None
_SEQUENCE_PARALLEL_GROUP = None
_SEQUENCE_PARALLEL_CPU_GROUP = None


_CFG_PARALLEL_GROUP = None
_CFG_PARALLEL_CPU_GROUP = None


def get_data_parallel_group() -> Optional[dist.ProcessGroup]:
    """
    Get data parallel process group.
    """
    return _DATA_PARALLEL_GROUP


def get_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
    """
    Get sequence parallel process group.
    """
    return _SEQUENCE_PARALLEL_GROUP


def get_sequence_parallel_cpu_group() -> Optional[dist.ProcessGroup]:
    """
    Get sequence parallel CPU process group.
    """
    return _SEQUENCE_PARALLEL_CPU_GROUP


def get_data_parallel_rank() -> int:
    """
    Get data parallel rank.
    """
    group = get_data_parallel_group()
    return dist.get_rank(group) if group else get_global_rank()


def get_data_parallel_world_size() -> int:
    """
    Get data parallel world size.
    """
    group = get_data_parallel_group()
    return dist.get_world_size(group) if group else get_world_size()


def get_sequence_parallel_rank() -> int:
    """
    Get sequence parallel rank.
    """
    group = get_sequence_parallel_group()
    return dist.get_rank(group) if group else 0


def get_sequence_parallel_world_size() -> int:
    """
    Get sequence parallel world size.
    """
    group = get_sequence_parallel_group()
    return dist.get_world_size(group) if group else 1


def init_unified_parallel(unified_parallel_size):
    global _SEQUENCE_PARALLEL_GROUP
    global _SEQUENCE_PARALLEL_CPU_GROUP

    if unified_parallel_size == 1:
        return

    assert dist.is_initialized()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % unified_parallel_size == 0
    data_parallel_size = world_size // unified_parallel_size

    for i in range(data_parallel_size):
        # build unified parallel group
        start_rank = i * unified_parallel_size
        end_rank = start_rank + unified_parallel_size
        unified_parallel_ranks = range(start_rank, end_rank)
        unified_parallel_group = dist.new_group(unified_parallel_ranks)
        unified_parallel_cpu_group = dist.new_group(unified_parallel_ranks, backend="gloo")
        if rank in unified_parallel_ranks:
            _SEQUENCE_PARALLEL_GROUP = unified_parallel_group
            _SEQUENCE_PARALLEL_CPU_GROUP = unified_parallel_cpu_group


def get_unified_parallel_group():
    return _SEQUENCE_PARALLEL_GROUP


def get_unified_parallel_cpu_group():
    return _SEQUENCE_PARALLEL_CPU_GROUP


def get_unified_parallel_rank():
    group = get_unified_parallel_group()
    return dist.get_rank(group) if group else 0


def get_unified_parallel_world_size():
    group = get_unified_parallel_group()
    return dist.get_world_size(group) if group else 1


def is_unified_parallel_initialized():
    group = get_unified_parallel_group()
    return group is not None
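

# For example, with dist.get_world_size() == 8 and unified_parallel_size == 4,
# init_unified_parallel above builds two groups covering ranks [0, 1, 2, 3] and
# [4, 5, 6, 7] (each with a matching gloo CPU group); every rank keeps only the
# pair that contains it, so rank 5 sees get_unified_parallel_rank() == 1 and
# get_unified_parallel_world_size() == 4.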
            +
             | 
| 137 | 
            +
             | 
| 138 | 
            +
            def pad_tensor(x: Tensor, dim: int, padding_size: int):
         | 
| 139 | 
            +
                shape = list(x.shape)
         | 
| 140 | 
            +
                shape[dim] = padding_size
         | 
| 141 | 
            +
                pad = torch.zeros(shape, dtype=x.dtype, device=x.device)
         | 
| 142 | 
            +
                return torch.cat([x, pad], dim=dim)
         | 
| 143 | 
            +
             | 
| 144 | 
            +
             | 
| 145 | 
            +
            class Slice(torch.autograd.Function):
         | 
| 146 | 
            +
                @staticmethod
         | 
| 147 | 
            +
                def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, dim: int, scale_grad: bool) -> Tensor:
         | 
| 148 | 
            +
                    ctx.group = group
         | 
| 149 | 
            +
                    ctx.rank = dist.get_rank(group)
         | 
| 150 | 
            +
                    seq_world_size = dist.get_world_size(group)
         | 
| 151 | 
            +
                    ctx.seq_world_size = seq_world_size
         | 
| 152 | 
            +
                    ctx.dim = dim
         | 
| 153 | 
            +
                    ctx.scale_grad = scale_grad
         | 
| 154 | 
            +
                    dim_size = local_input.shape[dim]
         | 
| 155 | 
            +
                    if not ctx.group:
         | 
| 156 | 
            +
                        return local_input
         | 
| 157 | 
            +
                    return local_input.split(dim_size // seq_world_size, dim=dim)[ctx.rank].contiguous()
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                @staticmethod
         | 
| 160 | 
            +
                def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None]:
         | 
| 161 | 
            +
                    if not ctx.group:
         | 
| 162 | 
            +
                        return None, grad_output, None, None
         | 
| 163 | 
            +
                    dim_size = list(grad_output.size())
         | 
| 164 | 
            +
                    split_size = dim_size[0]
         | 
| 165 | 
            +
                    dim_size[0] = dim_size[0] * ctx.seq_world_size
         | 
| 166 | 
            +
                    output = torch.empty(dim_size, dtype=grad_output.dtype, device=torch.cuda.current_device())
         | 
| 167 | 
            +
                    dist.all_gather_into_tensor(output, grad_output, group=ctx.group)
         | 
| 168 | 
            +
                    if ctx.scale_grad:
         | 
| 169 | 
            +
                        output = output / ctx.seq_world_size
         | 
| 170 | 
            +
                    return (None, torch.cat(output.split(split_size), dim=ctx.dim), None, None)
         | 
| 171 | 
            +
             | 
| 172 | 
            +
             | 
| 173 | 
            +
            def gather_outputs(
         | 
| 174 | 
            +
                x: Tensor,
         | 
| 175 | 
            +
                gather_dim: int,
         | 
| 176 | 
            +
                padding_dim: Optional[int] = None,
         | 
| 177 | 
            +
                unpad_dim_size: Optional[int] = None,
         | 
| 178 | 
            +
                scale_grad=True,
         | 
| 179 | 
            +
            ):
         | 
| 180 | 
            +
                """
         | 
| 181 | 
            +
                A func to gather the outputs for the model result in sequence parallel
         | 
| 182 | 
            +
                """
         | 
| 183 | 
            +
                group = get_unified_parallel_group()
         | 
| 184 | 
            +
                if not group:
         | 
| 185 | 
            +
                    return x
         | 
| 186 | 
            +
                x = Gather.apply(group, x, gather_dim, scale_grad)
         | 
| 187 | 
            +
                if padding_dim is not None:
         | 
| 188 | 
            +
                    x = unpadding_tensor_for_seqeunce_parallel(x, padding_dim, unpad_dim_size)
         | 
| 189 | 
            +
                return x
         | 
| 190 | 
            +
             | 
| 191 | 
            +
             | 
| 192 | 
            +
            def unpadding_tensor_for_seqeunce_parallel(x: Tensor, dim: int, unpadded_dim_size: int):
         | 
| 193 | 
            +
                """
         | 
| 194 | 
            +
                A func to remove the padding part of the tensor based on its original shape
         | 
| 195 | 
            +
                """
         | 
| 196 | 
            +
                group = get_unified_parallel_group()
         | 
| 197 | 
            +
                if group is None:
         | 
| 198 | 
            +
                    return x
         | 
| 199 | 
            +
                sp_world = get_unified_parallel_world_size()
         | 
| 200 | 
            +
                if unpadded_dim_size % sp_world == 0:
         | 
| 201 | 
            +
                    return x
         | 
| 202 | 
            +
                padding_size = sp_world - (unpadded_dim_size % sp_world)
         | 
| 203 | 
            +
                assert (padding_size + unpadded_dim_size) % sp_world == 0
         | 
| 204 | 
            +
                return unpad_tensor(x, dim=dim, padding_size=padding_size)
         | 
| 205 | 
            +
             | 
| 206 | 
            +
             | 
| 207 | 
            +
            def gather_seq_scatter_heads_qkv(
         | 
| 208 | 
            +
                qkv_tensor: Tensor,
         | 
| 209 | 
            +
                seq_dim: int,
         | 
| 210 | 
            +
                unpadded_dim_size: Optional[int] = None,
         | 
| 211 | 
            +
                restore_shape: bool = True,
         | 
| 212 | 
            +
                async_op: bool = False,
         | 
| 213 | 
            +
            ):
         | 
| 214 | 
            +
                """
         | 
| 215 | 
            +
                A func to sync splited qkv tensor
         | 
| 216 | 
            +
                qkv_tensor: the tensor we want to do alltoall with. The last dim must
         | 
| 217 | 
            +
                    be the projection_idx, which we will split into 3 part. After
         | 
| 218 | 
            +
                    spliting, the gather idx will be projecttion_idx + 1
         | 
| 219 | 
            +
                seq_dim: gather_dim for all2all comm
         | 
| 220 | 
            +
                restore_shape: if True, output will has the same shape length as input
         | 
| 221 | 
            +
                """
         | 
| 222 | 
            +
                group = get_unified_parallel_group()
         | 
| 223 | 
            +
                if not group:
         | 
| 224 | 
            +
                    return qkv_tensor
         | 
| 225 | 
            +
                world = get_unified_parallel_world_size()
         | 
| 226 | 
            +
                orig_shape = qkv_tensor.shape
         | 
| 227 | 
            +
                scatter_dim = qkv_tensor.dim()
         | 
| 228 | 
            +
                bef_all2all_shape = list(orig_shape)
         | 
| 229 | 
            +
                qkv_proj_dim = bef_all2all_shape[-1]
         | 
| 230 | 
            +
                bef_all2all_shape = bef_all2all_shape[:-1] + [3, qkv_proj_dim // 3]
         | 
| 231 | 
            +
                qkv_tensor = qkv_tensor.view(bef_all2all_shape)
         | 
| 232 | 
            +
                if async_op:
         | 
| 233 | 
            +
                    return SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op)
         | 
| 234 | 
            +
                else:
         | 
| 235 | 
            +
                    qkv_tensor = SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op)
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                    if restore_shape:
         | 
| 238 | 
            +
                        out_shape = list(orig_shape)
         | 
| 239 | 
            +
                        out_shape[seq_dim] *= world
         | 
| 240 | 
            +
                        out_shape[-1] = qkv_proj_dim // world
         | 
| 241 | 
            +
                        qkv_tensor = qkv_tensor.view(out_shape)
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                    # remove padding
         | 
| 244 | 
            +
                    if unpadded_dim_size and unpadded_dim_size % world != 0:
         | 
| 245 | 
            +
                        padding_size = qkv_tensor.size(seq_dim) - unpadded_dim_size
         | 
| 246 | 
            +
                        qkv_tensor = unpad_tensor(qkv_tensor, seq_dim, padding_size)
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                    return qkv_tensor
         | 
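
# Shape bookkeeping for gather_seq_scatter_heads_qkv above: with
# sequence-parallel world size W, a local fused tensor [B, S/W, ..., 3*H] is
# viewed as [B, S/W, ..., 3, H], all-to-all exchanged so that the sequence dim
# is gathered while the channel dim is scattered, and (with restore_shape=True)
# reshaped back to [B, S, ..., 3*H/W]: each rank ends up with the full sequence
# but only 1/W of the q, k, v channels.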


def gather_seq_scatter_double_head(
    qkv_tensor: Tensor,
    seq_dim: int,
    unpadded_dim_size: Optional[int] = None,
    restore_shape: bool = True,
    async_op: bool = False,
):
    """
    Sync a sequence-split fused double-head tensor with all-to-all.
    qkv_tensor: the tensor to do all-to-all on. The last dim must be the
        projection dim, which is split into 2 parts; after splitting, the
        gather idx will be projection_idx + 1.
    seq_dim: gather dim for the all-to-all comm
    restore_shape: if True, the output has the same number of dims as the input
    """
    group = get_unified_parallel_group()
    if not group:
        return qkv_tensor
    world = get_unified_parallel_world_size()
    orig_shape = qkv_tensor.shape
    scatter_dim = qkv_tensor.dim()
    bef_all2all_shape = list(orig_shape)
    qkv_proj_dim = bef_all2all_shape[-1]
    bef_all2all_shape = bef_all2all_shape[:-1] + [2, qkv_proj_dim // 2]
    qkv_tensor = qkv_tensor.view(bef_all2all_shape)
    if async_op:
        return SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op)
    else:
        qkv_tensor = SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op)

        if restore_shape:
            out_shape = list(orig_shape)
            out_shape[seq_dim] *= world
            out_shape[-1] = qkv_proj_dim // world
            qkv_tensor = qkv_tensor.view(out_shape)

        # remove padding
        if unpadded_dim_size and unpadded_dim_size % world != 0:
            padding_size = qkv_tensor.size(seq_dim) - unpadded_dim_size
            qkv_tensor = unpad_tensor(qkv_tensor, seq_dim, padding_size)

        return qkv_tensor


class SeqAllToAll(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_input: Tensor,
        scatter_dim: int,
        gather_dim: int,
        async_op: bool,
    ) -> Tensor:
        ctx.group = group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.async_op = async_op
        return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None, None]:
        if ctx.async_op:
            input_t = torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous()
        else:
            input_t = grad_output[0]
        return (
            None,
            all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False),
            None,
            None,
            None,
        )


def all_to_all_tensor(
    x: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: dist.ProcessGroup,
    async_op: bool = False,
):
    if scatter_dim <= 1 and gather_dim <= 1:
        return _all_to_all_single(x, scatter_dim, gather_dim, group, async_op)
    else:
        return _all_to_all(x, scatter_dim, gather_dim, group, async_op)  # this branch is taken


def _all_to_all(
    local_input: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: dist.ProcessGroup,
    async_op: bool = False,
):
    seq_world_size = dist.get_world_size(group)
    input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)]
    output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
    comm = dist.all_to_all(output_list, input_list, group=group, async_op=async_op)
    if async_op:

        def wait():
            comm.wait()
            return torch.cat(output_list, dim=gather_dim).contiguous()

        return wait
    return torch.cat(output_list, dim=gather_dim).contiguous()
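
# In _all_to_all above, each of the W ranks splits its tensor into W chunks
# along scatter_dim, exchanges chunk i with rank i, and concatenates the
# received chunks along gather_dim, so the per-rank sizes at
# (scatter_dim, gather_dim) go from (s, g) to (s / W, g * W) while the total
# element count per rank is preserved.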
            +
             | 
| 365 | 
            +
             | 
| 366 | 
            +
            def _all_to_all_single(x: Tensor, scatter_dim: int, gather_dim: int, group: dist.ProcessGroup, async_op: bool = False):
         | 
| 367 | 
            +
                """
         | 
| 368 | 
            +
                A function to do all-to-all on the first two dim
         | 
| 369 | 
            +
                """
         | 
| 370 | 
            +
                sp_world_size = dist.get_world_size(group)
         | 
| 371 | 
            +
                assert scatter_dim <= 1, "scatter_dim must be 0 or 1 when using all_to_all_single!"
         | 
| 372 | 
            +
                assert gather_dim <= 1, "gather_dim must be 0 or 1 when using all_to_all_single!"
         | 
| 373 | 
            +
                if scatter_dim != 0:
         | 
| 374 | 
            +
                    gather_dim_bef = x.shape[gather_dim]
         | 
| 375 | 
            +
                    scatter_dim_bef = x.shape[scatter_dim]
         | 
| 376 | 
            +
                    x = (
         | 
| 377 | 
            +
                        x.reshape([gather_dim_bef, sp_world_size, scatter_dim_bef // sp_world_size] + list(x.shape[2:]))
         | 
| 378 | 
            +
                        .transpose(0, 1)
         | 
| 379 | 
            +
                        .reshape([gather_dim_bef * sp_world_size, scatter_dim_bef // sp_world_size] + list(x.shape[2:]))
         | 
| 380 | 
            +
                        .contiguous()
         | 
| 381 | 
            +
                    )
         | 
| 382 | 
            +
             | 
| 383 | 
            +
                output = torch.empty_like(x)
         | 
| 384 | 
            +
                comm = dist.all_to_all_single(output, x.contiguous(), group=group, async_op=async_op)
         | 
| 385 | 
            +
             | 
| 386 | 
            +
                if async_op:
         | 
| 387 | 
            +
             | 
| 388 | 
            +
                    def wait():
         | 
| 389 | 
            +
                        comm.wait()
         | 
| 390 | 
            +
                        if scatter_dim == 0:
         | 
| 391 | 
            +
                            return torch.cat(output.split(x.size(0) // sp_world_size), dim=gather_dim)
         | 
| 392 | 
            +
                        else:
         | 
| 393 | 
            +
                            return output
         | 
| 394 | 
            +
             | 
| 395 | 
            +
                    return wait
         | 
| 396 | 
            +
             | 
| 397 | 
            +
                if scatter_dim == 0:
         | 
| 398 | 
            +
                    output = torch.cat(output.split(x.size(0) // sp_world_size), dim=gather_dim)
         | 
| 399 | 
            +
                return output
         | 
| 400 | 
            +
             | 
| 401 | 
            +
             | 
| 402 | 
            +
            def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int) -> Tensor:
         | 
| 403 | 
            +
                """
         | 
| 404 | 
            +
                A func to sync attention result with alltoall in sequence parallel
         | 
| 405 | 
            +
                """
         | 
| 406 | 
            +
                group = get_unified_parallel_group()
         | 
| 407 | 
            +
                if not group:
         | 
| 408 | 
            +
                    return x
         | 
| 409 | 
            +
                dim_size = x.size(seq_dim)
         | 
| 410 | 
            +
                sp_world = get_unified_parallel_world_size()
         | 
| 411 | 
            +
                if dim_size % sp_world != 0:
         | 
| 412 | 
            +
                    padding_size = sp_world - (dim_size % sp_world)
         | 
| 413 | 
            +
                    x = pad_tensor(x, seq_dim, padding_size)
         | 
| 414 | 
            +
                return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)
         | 
| 415 | 
            +
             | 
| 416 | 
            +
             | 
| 417 | 
            +
            def unpad_tensor(x: Tensor, dim: int, padding_size: int):
         | 
| 418 | 
            +
                slc = [slice(None)] * len(x.shape)
         | 
| 419 | 
            +
                slc[dim] = slice(0, -padding_size)
         | 
| 420 | 
            +
                return x[slc]
         | 
| 421 | 
            +
             | 
| 422 | 
            +
             | 
| 423 | 
            +
            class Gather(torch.autograd.Function):
         | 
| 424 | 
            +
                @staticmethod
         | 
| 425 | 
            +
                def forward(
         | 
| 426 | 
            +
                    ctx: Any,
         | 
| 427 | 
            +
                    group: dist.ProcessGroup,
         | 
| 428 | 
            +
                    local_input: Tensor,
         | 
| 429 | 
            +
                    dim: int,
         | 
| 430 | 
            +
                    grad_scale: Optional[bool] = False,
         | 
| 431 | 
            +
                ) -> Tensor:
         | 
| 432 | 
            +
                    ctx.group = group
         | 
| 433 | 
            +
                    ctx.rank = dist.get_rank(group)
         | 
| 434 | 
            +
                    ctx.dim = dim
         | 
| 435 | 
            +
                    ctx.grad_scale = grad_scale
         | 
| 436 | 
            +
                    seq_world_size = dist.get_world_size(group)
         | 
| 437 | 
            +
                    ctx.seq_world_size = seq_world_size
         | 
| 438 | 
            +
                    dim_size = list(local_input.size())
         | 
| 439 | 
            +
                    split_size = dim_size[0]
         | 
| 440 | 
            +
                    ctx.part_size = dim_size[dim]
         | 
| 441 | 
            +
                    dim_size[0] = dim_size[0] * seq_world_size
         | 
| 442 | 
            +
                    output = torch.empty(dim_size, dtype=local_input.dtype, device=torch.cuda.current_device())
         | 
| 443 | 
            +
                    dist.all_gather_into_tensor(output, local_input.contiguous(), group=ctx.group)
         | 
| 444 | 
            +
                    return torch.cat(output.split(split_size), dim=dim)
         | 
| 445 | 
            +
             | 
| 446 | 
            +
                @staticmethod
         | 
| 447 | 
            +
                def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor]:
         | 
| 448 | 
            +
                    if ctx.grad_scale:
         | 
| 449 | 
            +
                        grad_output = grad_output * ctx.seq_world_size
         | 
| 450 | 
            +
                    return (
         | 
| 451 | 
            +
                        None,
         | 
| 452 | 
            +
                        grad_output.split(ctx.part_size, dim=ctx.dim)[ctx.rank].contiguous(),
         | 
| 453 | 
            +
                        None,
         | 
| 454 | 
            +
                        None,
         | 
| 455 | 
            +
                    )
         | 
| 456 | 
            +
             | 
| 457 | 
            +
             | 
| 458 | 
            +
            def slice_tensor(tensor, dim, start, end):
         | 
| 459 | 
            +
                indices = slice(start, end)
         | 
| 460 | 
            +
                return tensor[(slice(None),) * dim + (indices,)]
         | 
| 461 | 
            +
             | 
| 462 | 
            +
             | 
| 463 | 
            +
            def init_model_shard_cpu_group(sharding_strategy: str, device_mesh: Optional[Tuple] = None):
         | 
| 464 | 
            +
                """
         | 
| 465 | 
            +
                Initialize CPU process group of model sharding.
         | 
| 466 | 
            +
                """
         | 
| 467 | 
            +
                global _MODEL_SHARD_CPU_GROUP
         | 
| 468 | 
            +
                assert dist.is_initialized()
         | 
| 469 | 
            +
                world_size = dist.get_world_size()
         | 
| 470 | 
            +
                rank = dist.get_rank()
         | 
| 471 | 
            +
                if device_mesh is not None:
         | 
| 472 | 
            +
                    num_shards_per_group = device_mesh[1]
         | 
| 473 | 
            +
                elif "HYBRID" in sharding_strategy:
         | 
| 474 | 
            +
                    num_shards_per_group = min(8, world_size)
         | 
| 475 | 
            +
                else:
         | 
| 476 | 
            +
                    num_shards_per_group = world_size
         | 
| 477 | 
            +
                num_groups = world_size // num_shards_per_group
         | 
| 478 | 
            +
                for i in range(num_groups):
         | 
| 479 | 
            +
                    start_rank = i * num_shards_per_group
         | 
| 480 | 
            +
                    end_rank = (i + 1) * num_shards_per_group
         | 
| 481 | 
            +
                    ranks = range(start_rank, end_rank)
         | 
| 482 | 
            +
                    cpu_group = dist.new_group(ranks, backend="gloo")
         | 
| 483 | 
            +
                    if rank in ranks:
         | 
| 484 | 
            +
                        _MODEL_SHARD_CPU_GROUP = cpu_group
         | 
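
To make the head/sequence exchange above concrete, here is a minimal single-process sketch (my own illustration, not code from this repo) that simulates what the all-to-all behind these helpers computes for a world size of 4: each simulated rank starts with a sequence shard and the full head dimension, and ends with the full sequence and a head shard (`gather_heads_scatter_seq` runs the same exchange in the opposite direction after attention).

import torch

def simulated_all_to_all(chunks, scatter_dim, gather_dim):
    # chunks[r] plays the role of rank r's local tensor. Each rank splits its
    # tensor along scatter_dim, "sends" piece j to rank j, and concatenates
    # the received pieces along gather_dim, exactly like dist.all_to_all.
    world = len(chunks)
    pieces = [torch.tensor_split(c, world, dim=scatter_dim) for c in chunks]
    return [
        torch.cat([pieces[src][dst] for src in range(world)], dim=gather_dim)
        for dst in range(world)
    ]

world = 4
local = [torch.randn(2, 8, 64) for _ in range(world)]  # [seq_shard, heads, head_dim]
out = simulated_all_to_all(local, scatter_dim=1, gather_dim=0)
assert out[0].shape == (8, 2, 64)  # full sequence (2*4), heads split (8/4)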
    	
        humo/common/distributed/basic.py
    ADDED
    
    | @@ -0,0 +1,143 @@ | |
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Codes adapted from [SeedVR]
+# https://github.com/ByteDance-Seed/SeedVR/tree/main/common/distributed
+
+"""
+Distributed basic functions.
+"""
+
+import os
+import torch
+from torch import nn
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel
+from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
+
+
+def get_global_rank() -> int:
+    """
+    Get the global rank, the global index of the GPU.
+    """
+    return int(os.environ.get("RANK", "0"))
+
+
+def get_local_rank() -> int:
+    """
+    Get the local rank, the index of the GPU on its node.
+    """
+    return int(os.environ.get("LOCAL_RANK", "0"))
+
+
+def get_world_size() -> int:
+    """
+    Get the world size, the total number of GPUs.
+    """
+    return int(os.environ.get("WORLD_SIZE", "1"))
+
+
+def get_device() -> torch.device:
+    """
+    Get the device of the current rank.
+    """
+    return torch.device("cuda", get_local_rank())
+
+
+def barrier_if_distributed(*args, **kwargs):
+    """
+    Synchronize all processes if under a distributed context.
+    """
+    if dist.is_initialized():
+        return dist.barrier(*args, **kwargs)
+
+
+def init_torch(cudnn_benchmark=True):
+    """
+    Common PyTorch initialization configuration.
+    """
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+    torch.backends.cudnn.benchmark = cudnn_benchmark
+    torch.cuda.set_device(get_local_rank())
+    dist.init_process_group(
+        backend="nccl",
+        rank=get_global_rank(),
+        world_size=get_world_size(),
+    )
+
+
+def convert_to_ddp(module: torch.nn.Module, **kwargs) -> DistributedDataParallel:
+    return DistributedDataParallel(
+        module=module,
+        device_ids=[get_local_rank()],
+        output_device=get_local_rank(),
+        **kwargs,
+    )
+
+
+def meta_param_init_fn(module: nn.Module) -> None:
+    """
+    Used for models initialized onto the meta device.
+    Initializes meta params/buffers with empty tensors.
+    Numerical correctness does not matter here:
+    FSDP will sync param/buffer state from rank 0 to the other ranks.
+    """
+
+    with torch.no_grad():
+        for submodule in module.modules():
+            for param_name, param in submodule.named_parameters(recurse=False):
+                if not _is_fsdp_flattened(param) and param.is_meta:
+                    materialized_param = nn.Parameter(torch.empty_like(param, device="cpu"))
+                    setattr(submodule, param_name, materialized_param)
+            for buffer_name, buffer in submodule.named_buffers(recurse=False):
+                if not _is_fsdp_flattened(buffer) and buffer.is_meta:
+                    materialized_param = torch.empty_like(buffer, device="cpu")
+                    setattr(submodule, buffer_name, materialized_param)
+            torch.cuda.empty_cache()
+
+
+def meta_non_persistent_buffer_init_fn(module: nn.Module) -> nn.Module:
+    """
+    Materialize meta-device buffers that are not persistent in the state_dict.
+    Handles special cases like RotaryEmbedding.freqs.
+    """
+    with torch.no_grad():
+        for submodule in module.modules():
+            if hasattr(submodule, "freqs"):
+                freqs = getattr(submodule, "freqs")
+                if isinstance(freqs, torch.Tensor) and freqs.is_meta:
+                    def rope_params(max_seq_len, dim, theta=10000):
+                        assert dim % 2 == 0
+                        freqs = torch.outer(
+                            torch.arange(max_seq_len),
+                            1.0 / torch.pow(theta,
+                                torch.arange(0, dim, 2).to(torch.float32).div(dim)))
+                        freqs = torch.polar(torch.ones_like(freqs), freqs)
+                        return freqs
+
+                    # Hard-coded for the 14B model; use dim=1536, num_heads=12 for 1.3B.
+                    dim = 5120
+                    num_heads = 40
+                    d = dim // num_heads
+                    freqs_tensor = torch.cat([
+                        rope_params(1024, d - 4 * (d // 6)),
+                        rope_params(1024, 2 * (d // 6)),
+                        rope_params(1024, 2 * (d // 6))
+                    ], dim=1).to(dtype=torch.cfloat, device="cpu")
+
+                    setattr(submodule, "freqs", freqs_tensor)
+                    print(f"Successfully materialized freqs for {submodule.__class__.__name__}")
+
+    assert not any(b.is_meta for n, b in module.named_buffers())
+    return module
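
As a quick usage sketch (my own, assuming the standard `torchrun` launcher, e.g. `torchrun --nproc_per_node=8 script.py`, which sets the RANK/LOCAL_RANK/WORLD_SIZE variables these helpers read):

import torch
from common.distributed import init_torch, get_device, barrier_if_distributed

def main():
    init_torch(cudnn_benchmark=True)         # TF32 flags + NCCL process group
    x = torch.ones(4, device=get_device())   # tensor on cuda:<LOCAL_RANK>
    barrier_if_distributed()                 # all ranks sync here

if __name__ == "__main__":
    main()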
    	
        humo/common/logger.py
    ADDED
    
    | @@ -0,0 +1,44 @@ | |
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Codes adapted from [SeedVR]
+# https://github.com/ByteDance-Seed/SeedVR/blob/main/common/logger.py
+
+"""
+Logging utility functions.
+"""
+
+import logging
+import sys
+from typing import Optional
+
+from common.distributed import get_global_rank, get_local_rank, get_world_size
+
+_default_handler = logging.StreamHandler(sys.stdout)
+_default_handler.setFormatter(
+    logging.Formatter(
+        "%(asctime)s "
+        + (f"[Rank:{get_global_rank()}]" if get_world_size() > 1 else "")
+        + (f"[LocalRank:{get_local_rank()}]" if get_world_size() > 1 else "")
+        + "[%(threadName).12s][%(name)s][%(levelname).5s] "
+        + "%(message)s"
+    )
+)
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+    """
+    Get a logger.
+    """
+    logger = logging.getLogger(name)
+    logger.addHandler(_default_handler)
+    logger.setLevel(logging.INFO)
+    return logger
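
A brief usage note (illustrative): the handler is a module-level singleton and `logging.getLogger` caches by name, so repeated calls return the same configured logger and `addHandler` will not attach the same handler twice. Callers elsewhere in this upload simply do:

from common.logger import get_logger

logger = get_logger(__name__)
logger.info("initialized on %d GPU(s)", 8)  # prefixed with rank info when WORLD_SIZE > 1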
    	
        humo/configs/inference/generate.yaml
    ADDED
    
    | @@ -0,0 +1,78 @@ | |
+__object__:
+  path: humo.generate
+  name: Generator
+
+dit:
+  model:
+    __inherit__: humo/configs/models/Wan_14B_I2V.yaml
+    __object__:
+      path: humo.models.wan_modules.model_humo
+      name: WanModel
+    insert_audio: True
+  zero_vae_path: ./weights/HuMo/zero_vae_129frame.pt
+  zero_vae_720p_path: ./weights/HuMo/zero_vae_720p_161frame.pt
+  checkpoint_dir: ./weights/HuMo/HuMo-17B
+  compile: False
+  init_with_meta_device: True
+  gradient_checkpoint: True
+  fsdp:
+    sharding_strategy: _HYBRID_SHARD_ZERO2
+  sp_size: 1
+  dtype: bfloat16
+
+vae:
+  checkpoint: ./weights/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth
+  vae_stride: [ 4, 8, 8 ]
+  scaling_factor: 0.9152
+  compile: False
+  grouping: True
+  use_sample: False
+  dtype: bfloat16
+
+text:
+  t5_checkpoint: ./weights/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth
+  t5_tokenizer: ./weights/Wan2.1-T2V-1.3B/google/umt5-xxl
+  dropout: 0.1
+  dtype: bfloat16
+  fsdp:
+    enabled: True
+    sharding_strategy: HYBRID_SHARD
+
+diffusion:
+  schedule:
+    type: lerp
+    T: 1000.0
+  sampler:
+    type: euler
+    prediction_type: v_lerp
+  timesteps:
+    training:
+      type: logitnormal
+      loc: 0.0
+      scale: 1.0
+    sampling:
+      type: uniform_trailing
+      steps: 50
+      shift: 5.0
+
+audio:
+  vocal_separator: ./weights/HuMo/audio_separator/Kim_Vocal_2.onnx
+  wav2vec_model: ./weights/whisper-large-v3
+
+generation:
+  mode: "TIA"  # TA or TIA
+  extract_audio_feat: True
+  seed: 666666
+  frames: 97
+  fps: 25
+  height: 480  # 720 or 480
+  width: 832   # 1280 or 832
+  batch_size: 1
+  sequence_parallel: 8
+  output:
+    dir: ./output
+  # positive_prompt: ./examples/test_case.json
+  # Negative prompt (Chinese; roughly: vivid tones, overexposed, static, blurry details,
+  # subtitles, style, artwork, painting, stillness, overall gray, worst/low quality,
+  # JPEG compression artifacts, ugly, incomplete, extra fingers, poorly drawn hands/faces,
+  # deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered
+  # background, three legs, crowded background, walking backwards).
+  sample_neg_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
+  scale_a: 5.5
+  scale_t: 5.0
+  step_change: 980
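
For orientation, a sketch of how a config like this is typically consumed. Hedged: `create_object`'s exact contract lives in humo/common/config.py, which is part of this upload but not shown here; I am assuming it resolves `__object__.path`/`name` to a class and instantiates it.

from omegaconf import OmegaConf
from common.config import create_object

config = OmegaConf.load("humo/configs/inference/generate.yaml")
print(config.generation.mode)                     # "TIA"
print(config.diffusion.timesteps.sampling.steps)  # 50
generator = create_object(config)  # assumed to build humo.generate.Generator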
    	
        humo/configs/inference/generate_1_7B.yaml
    ADDED
    
    | @@ -0,0 +1,76 @@ | |
+__object__:
+  path: humo.generate_1_7B
+  name: Generator
+
+dit:
+  model:
+    __inherit__: humo/configs/models/Wan_1.3B.yaml
+    __object__:
+      path: humo.models.wan_modules.model_humo
+      name: WanModel
+    insert_audio: True
+  zero_vae_path: ./weights/HuMo/zero_vae_129frame.pt
+  zero_vae_720p_path: ./weights/HuMo/zero_vae_720p_161frame.pt
+  checkpoint_dir: ./weights/HuMo/HuMo-1.7B/ema.pth  # (14B alternative: ./weights/HuMo/HuMo-17B)
+  compile: False
+  init_with_meta_device: True
+  gradient_checkpoint: True
+  fsdp:
+    sharding_strategy: _HYBRID_SHARD_ZERO2
+  sp_size: 1
+
+vae:
+  checkpoint: ./weights/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth
+  vae_stride: [ 4, 8, 8 ]
+  scaling_factor: 0.9152
+  compile: False
+  grouping: True
+  use_sample: False
+  dtype: bfloat16
+
+text:
+  t5_checkpoint: ./weights/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth
+  t5_tokenizer: ./weights/Wan2.1-T2V-1.3B/google/umt5-xxl
+  dropout: 0.1
+  dtype: bfloat16
+  fsdp:
+    enabled: True
+    sharding_strategy: HYBRID_SHARD
+
+diffusion:
+  schedule:
+    type: lerp
+    T: 1000.0
+  sampler:
+    type: euler
+    prediction_type: v_lerp
+  timesteps:
+    training:
+      type: logitnormal
+      loc: 0.0
+      scale: 1.0
+    sampling:
+      type: uniform_trailing
+      steps: 50
+      shift: 5.0
+
+audio:
+  vocal_separator: ./weights/audio_separator/Kim_Vocal_2.onnx
+  wav2vec_model: ./weights/whisper-large-v3
+
+generation:
+  mode: "TIA"  # TA or TIA
+  extract_audio_feat: True
+  seed: 666666
+  frames: 97
+  fps: 25
+  height: 720  # or 480
+  width: 1280  # or 832
+  batch_size: 1
+  output:
+    dir: ./output
+  # Negative prompt: same Chinese prompt as in generate.yaml (see the translation there).
+  sample_neg_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
+  scale_t: 7.5
+  scale_i: 4.0
+  scale_a: 7.5
+  # step_change: 980
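
A worked check of the geometry these settings imply (pure arithmetic; assumes the usual causal-VAE convention that 4n+1 input frames map to n+1 latent frames, matching vae_stride [4, 8, 8]):

frames, fps = 97, 25
height, width = 720, 1280
t_stride, h_stride, w_stride = 4, 8, 8

clip_seconds = frames / fps                          # 3.88 s of video
latent_frames = (frames - 1) // t_stride + 1         # 25 latent frames
latent_hw = (height // h_stride, width // w_stride)  # (90, 160)
print(clip_seconds, latent_frames, latent_hw)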
    	
        humo/configs/models/Wan_1.3B.yaml
    ADDED
    
    | @@ -0,0 +1,17 @@ | |
+__object__:
+  path: ???
+  name: ???
+  args: as_params
+
+text_len: 512
+patch_size: [ 1, 2, 2 ]
+dim: 1536
+ffn_dim: 8960
+freq_dim: 256
+model_type: "t2v"
+num_heads: 12
+num_layers: 30
+window_size: [ -1, -1 ]
+qk_norm: True
+cross_attn_norm: True
+eps: 1e-6
    	
        humo/configs/models/Wan_1.3B_I2V.yaml
    ADDED
    
    | @@ -0,0 +1,18 @@ | |
+__object__:
+  path: ???
+  name: ???
+  args: as_params
+
+text_len: 512
+patch_size: [ 1, 2, 2 ]
+dim: 1536
+ffn_dim: 8960
+freq_dim: 256
+in_dim: 36
+model_type: "i2v"
+num_heads: 12
+num_layers: 30
+window_size: [ -1, -1 ]
+qk_norm: True
+cross_attn_norm: True
+eps: 1e-6
    	
        humo/configs/models/Wan_14B.yaml
    ADDED
    
    | @@ -0,0 +1,17 @@ | |
+__object__:
+  path: ???
+  name: ???
+  args: as_params
+
+text_len: 512
+patch_size: [ 1, 2, 2 ]
+dim: 5120
+ffn_dim: 13824
+freq_dim: 256
+model_type: "t2v"
+num_heads: 40
+num_layers: 40
+window_size: [ -1, -1 ]
+qk_norm: True
+cross_attn_norm: True
+eps: 1e-6
    	
        humo/configs/models/Wan_14B_I2V.yaml
    ADDED
    
    | @@ -0,0 +1,18 @@ | |
+__object__:
+  path: ???
+  name: ???
+  args: as_params
+
+text_len: 512
+patch_size: [ 1, 2, 2 ]
+dim: 5120
+ffn_dim: 13824
+freq_dim: 256
+in_dim: 36
+model_type: "i2v"
+num_heads: 40
+num_layers: 40
+window_size: [ -1, -1 ]
+qk_norm: True
+cross_attn_norm: True
+eps: 1e-6
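
One invariant worth spelling out across these four configs: both model widths keep a 128-dimensional attention head, and that head dim is exactly what the RoPE materialization in humo/common/distributed/basic.py splits into temporal and spatial parts:

for dim, num_heads in [(1536, 12), (5120, 40)]:
    d = dim // num_heads        # head dim: 128 for both model sizes
    t_part = d - 4 * (d // 6)   # 44 dims rotate with the frame index
    hw_part = 2 * (d // 6)      # 42 dims each for height and width
    assert d == 128 and t_part + 2 * hw_part == d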
    	
        humo/generate.py
    ADDED
    
    | @@ -0,0 +1,984 @@ | |
| 1 | 
            +
            # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
         | 
| 2 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 3 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 4 | 
            +
            # You may obtain a copy of the License at
         | 
| 5 | 
            +
            # http://www.apache.org/licenses/LICENSE-2.0
         | 
| 6 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 7 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 8 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 9 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 10 | 
            +
            # limitations under the License.
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            # Inference codes adapted from [SeedVR]
         | 
| 13 | 
            +
            # https://github.com/ByteDance-Seed/SeedVR/blob/main/projects/inference_seedvr2_7b.py
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            import math
         | 
| 16 | 
            +
            import os
         | 
| 17 | 
            +
            import gc
         | 
| 18 | 
            +
            import random
         | 
| 19 | 
            +
            import sys
         | 
| 20 | 
            +
            import mediapy
         | 
| 21 | 
            +
            import numpy as np
         | 
| 22 | 
            +
            import torch
         | 
| 23 | 
            +
            import torch.distributed as dist
         | 
| 24 | 
            +
            from omegaconf import DictConfig, ListConfig, OmegaConf
         | 
| 25 | 
            +
            from einops import rearrange
         | 
| 26 | 
            +
            from omegaconf import OmegaConf
         | 
| 27 | 
            +
            from PIL import Image, ImageOps
         | 
| 28 | 
            +
            from torchvision.transforms import ToTensor
         | 
| 29 | 
            +
            from tqdm import tqdm
         | 
| 30 | 
            +
            from torch.distributed.device_mesh import init_device_mesh
         | 
| 31 | 
            +
            from torch.distributed.fsdp import (
         | 
| 32 | 
            +
                BackwardPrefetch,
         | 
| 33 | 
            +
                FullyShardedDataParallel,
         | 
| 34 | 
            +
                MixedPrecision,
         | 
| 35 | 
            +
                ShardingStrategy,
         | 
| 36 | 
            +
            )
         | 
| 37 | 
            +
            from common.distributed import (
         | 
| 38 | 
            +
                get_device,
         | 
| 39 | 
            +
                get_global_rank,
         | 
| 40 | 
            +
                get_local_rank,
         | 
| 41 | 
            +
                meta_param_init_fn,
         | 
| 42 | 
            +
                meta_non_persistent_buffer_init_fn,
         | 
| 43 | 
            +
                init_torch,
         | 
| 44 | 
            +
            )
         | 
| 45 | 
            +
            from common.distributed.advanced import (
         | 
| 46 | 
            +
                init_unified_parallel,
         | 
| 47 | 
            +
                get_unified_parallel_world_size,
         | 
| 48 | 
            +
                get_sequence_parallel_rank,
         | 
| 49 | 
            +
                init_model_shard_cpu_group,
         | 
| 50 | 
            +
            )
         | 
| 51 | 
            +
            from common.logger import get_logger
         | 
| 52 | 
            +
            from common.config import create_object
         | 
| 53 | 
            +
            from common.distributed import get_device, get_global_rank
         | 
| 54 | 
            +
            from torchvision.transforms import Compose, Normalize, ToTensor
         | 
| 55 | 
            +
            from humo.models.wan_modules.t5 import T5EncoderModel
         | 
| 56 | 
            +
            from humo.models.wan_modules.vae import WanVAE
         | 
| 57 | 
            +
            from humo.models.utils.utils import tensor_to_video, prepare_json_dataset
         | 
| 58 | 
            +
            from contextlib import contextmanager
         | 
| 59 | 
            +
            import torch.cuda.amp as amp
         | 
| 60 | 
            +
            from humo.models.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
         | 
| 61 | 
            +
            from humo.utils.audio_processor_whisper import AudioProcessor
         | 
| 62 | 
            +
            from humo.utils.wav2vec import linear_interpolation_fps
         | 
| 63 | 
            +
            from torchao.quantization import quantize_
         | 
| 64 | 
            +
             | 
| 65 | 
            +
            import torch._dynamo as dynamo
         | 
| 66 | 
            +
            dynamo.config.capture_scalar_outputs = True
         | 
| 67 | 
            +
            torch.set_float32_matmul_precision("high")
         | 
| 68 | 
            +
             | 
| 69 | 
            +
            import torch
         | 
| 70 | 
            +
            import torch.nn as nn
         | 
| 71 | 
            +
            import transformer_engine.pytorch as te
         | 
| 72 | 
            +
             | 
| 73 | 
            +
            image_transform = Compose([
         | 
| 74 | 
            +
                ToTensor(),
         | 
| 75 | 
            +
                Normalize(mean=0.5, std=0.5),
         | 
| 76 | 
            +
            ])
         | 
| 77 | 
            +
             | 
| 78 | 
            +
            SIZE_CONFIGS = {
         | 
| 79 | 
            +
                '720*1280': (720, 1280),
         | 
| 80 | 
            +
                '1280*720': (1280, 720),
         | 
| 81 | 
            +
                '480*832': (480, 832),
         | 
| 82 | 
            +
                '832*480': (832, 480),
         | 
| 83 | 
            +
                '1024*1024': (1024, 1024),
         | 
| 84 | 
            +
            }
         | 
| 85 | 
            +
             | 
| 86 | 
            +
            def clever_format(nums, format="%.2f"):
         | 
| 87 | 
            +
                from typing import Iterable
         | 
| 88 | 
            +
                if not isinstance(nums, Iterable):
         | 
| 89 | 
            +
                    nums = [nums]
         | 
| 90 | 
            +
                clever_nums = []
         | 
| 91 | 
            +
                for num in nums:
         | 
| 92 | 
            +
                    if num > 1e12:
         | 
| 93 | 
            +
                        clever_nums.append(format % (num / 1e12) + "T")
         | 
| 94 | 
            +
                    elif num > 1e9:
         | 
| 95 | 
            +
                        clever_nums.append(format % (num / 1e9) + "G")
         | 
| 96 | 
            +
                    elif num > 1e6:
         | 
| 97 | 
            +
                        clever_nums.append(format % (num / 1e6) + "M")
         | 
| 98 | 
            +
                    elif num > 1e3:
         | 
| 99 | 
            +
                        clever_nums.append(format % (num / 1e3) + "K")
         | 
| 100 | 
            +
                    else:
         | 
| 101 | 
            +
                        clever_nums.append(format % num + "B")
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                clever_nums = clever_nums[0] if len(clever_nums) == 1 else (*clever_nums,)
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                return clever_nums
         | 
| 106 | 
            +
             | 
| 107 | 
            +
             | 
| 108 | 
            +
                
         | 
| 109 | 
            +
            # --- put near your imports ---
         | 
| 110 | 
            +
            import torch
         | 
| 111 | 
            +
            import torch.nn as nn
         | 
| 112 | 
            +
            import contextlib
         | 
| 113 | 
            +
            import transformer_engine.pytorch as te
         | 
| 114 | 
            +
             | 
| 115 | 
            +
            # FP8 autocast compatibility for different TE versions
         | 
| 116 | 
            +
            try:
         | 
| 117 | 
            +
                # Preferred modern API
         | 
| 118 | 
            +
                from transformer_engine.pytorch import fp8_autocast
         | 
| 119 | 
            +
                try:
         | 
| 120 | 
            +
                    # Newer TE: use recipe-based API
         | 
| 121 | 
            +
                    from transformer_engine.common.recipe import DelayedScaling, Format
         | 
| 122 | 
            +
                    def make_fp8_ctx(enabled: bool = True):
         | 
| 123 | 
            +
                        if not enabled:
         | 
| 124 | 
            +
                            return contextlib.nullcontext()
         | 
| 125 | 
            +
                        fp8_recipe = DelayedScaling(fp8_format=Format.E4M3)  # E4M3 format
         | 
| 126 | 
            +
                        return fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)
         | 
| 127 | 
            +
                except Exception:
         | 
| 128 | 
            +
                    # Very old variant that might still accept fp8_format directly
         | 
| 129 | 
            +
                    def make_fp8_ctx(enabled: bool = True):
         | 
| 130 | 
            +
                        # If TE doesn't have FP8Format, just no-op
         | 
| 131 | 
            +
                        if not hasattr(te, "FP8Format"):
         | 
| 132 | 
            +
                            return contextlib.nullcontext()
         | 
| 133 | 
            +
                        return te.fp8_autocast(enabled=enabled, fp8_format=te.FP8Format.E4M3)
         | 
| 134 | 
            +
            except Exception:
         | 
| 135 | 
            +
# TE not present or totally incompatible; fall back to a no-op context
         | 
| 136 | 
            +
                def make_fp8_ctx(enabled: bool = True):
         | 
| 137 | 
            +
                    return contextlib.nullcontext()
         | 
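Whichever branch defined it, `make_fp8_ctx` is used the same way. A minimal sketch, where `model` and `x` are placeholders and FP8 only engages on GPUs with hardware support:

with make_fp8_ctx(enabled=True):
    out = model(x)  # TE-converted Linears run their GEMMs in FP8 inside this context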
| 138 | 
            +
             | 
| 139 | 
            +
             | 
| 140 | 
            +
            # TE sometimes exposes Linear at different paths; this normalizes it.
         | 
| 141 | 
            +
            try:
         | 
| 142 | 
            +
                TELinear = te.Linear
         | 
| 143 | 
            +
            except AttributeError:  # very old layouts
         | 
| 144 | 
            +
                from transformer_engine.pytorch.modules.linear import Linear as TELinear  # type: ignore
         | 
| 145 | 
            +
             | 
| 165 | 
            +
            def _default_te_allow(fullname: str, lin: nn.Linear) -> bool:
         | 
| 166 | 
            +
                """
         | 
| 167 | 
            +
                Allow TE only where it's shape-safe & beneficial.
         | 
| 168 | 
            +
                Skip small/special layers (time/timestep/pos embeds, heads).
         | 
| 169 | 
            +
                Enforce multiples of 16 for in/out features (FP8 kernel friendly).
         | 
| 170 | 
            +
                Also skip very small projections likely to see M=1.
         | 
| 171 | 
            +
                """
         | 
| 172 | 
            +
                blocked_keywords = (
         | 
| 173 | 
            +
                    "time_embedding", "timestep", "time_embed",
         | 
| 174 | 
            +
                    "time_projection", "pos_embedding", "pos_embed",
         | 
| 175 | 
            +
                    "to_logits", "logits", "final_proj", "proj_out", "output_projection",
         | 
| 176 | 
            +
                )
         | 
| 177 | 
            +
                if any(k in fullname for k in blocked_keywords):
         | 
| 178 | 
            +
                    return False
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                # TE FP8 kernels like K, N divisible by 16
         | 
| 181 | 
            +
                if lin.in_features % 16 != 0 or lin.out_features % 16 != 0:
         | 
| 182 | 
            +
                    return False
         | 
| 183 | 
            +
             | 
| 184 | 
            +
    # Heuristic: skip tiny projections; full-size attention/MLP layers still qualify
         | 
| 185 | 
            +
                if lin.in_features < 512 or lin.out_features < 512:
         | 
| 186 | 
            +
                    return False
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                # Whitelist: only convert inside transformer blocks if you know their prefix
         | 
| 189 | 
            +
                # This further reduces risk of catching special heads elsewhere.
         | 
| 190 | 
            +
                allowed_context = ("blocks", "layers", "transformer", "attn", "mlp", "ffn")
         | 
| 191 | 
            +
                if not any(tok in fullname for tok in allowed_context):
         | 
| 192 | 
            +
                    return False
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                return True
         | 
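The allow-predicate is pluggable. A stricter variant that converts only attention projections might look like this (the name patterns are assumptions about the model's module naming, not guarantees):

def attn_only_te_allow(fullname: str, lin: nn.Linear) -> bool:
    # Keep the FP8 alignment requirement, then whitelist attention projections only.
    if lin.in_features % 16 != 0 or lin.out_features % 16 != 0:
        return False
    return any(tok in fullname for tok in ("attn.q", "attn.k", "attn.v", "attn.o"))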
| 195 | 
            +
             | 
| 196 | 
            +
            @torch.no_grad()
         | 
| 197 | 
            +
            def convert_linears_to_te_fp8(module: nn.Module, allow_pred=_default_te_allow, _prefix=""):
         | 
| 198 | 
            +
                for name, child in list(module.named_children()):
         | 
| 199 | 
            +
                    full = f"{_prefix}.{name}" if _prefix else name
         | 
| 200 | 
            +
                    convert_linears_to_te_fp8(child, allow_pred, full)
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                    if isinstance(child, nn.Linear):
         | 
| 203 | 
            +
                        if allow_pred is not None and not allow_pred(full, child):
         | 
| 204 | 
            +
                            continue
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                        te_lin = TELinear(
         | 
| 207 | 
            +
                            in_features=child.in_features,
         | 
| 208 | 
            +
                            out_features=child.out_features,
         | 
| 209 | 
            +
                            bias=(child.bias is not None),
         | 
| 210 | 
            +
                            params_dtype=torch.bfloat16,
         | 
| 211 | 
            +
                        ).to(child.weight.device)
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                        te_lin.weight.copy_(child.weight.to(te_lin.weight.dtype))
         | 
| 214 | 
            +
                        if child.bias is not None:
         | 
| 215 | 
            +
                            te_lin.bias.copy_(child.bias.to(te_lin.bias.dtype))
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                        setattr(module, name, te_lin)
         | 
| 218 | 
            +
                return module
         | 
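Putting the pieces together, the intended call pattern is roughly this sketch (`model` and `inputs` are placeholders):

model = convert_linears_to_te_fp8(model, allow_pred=_default_te_allow)
with make_fp8_ctx(True), torch.autocast("cuda", dtype=torch.bfloat16):
    out = model(inputs)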
| 219 | 
            +
             | 
| 220 | 
            +
class Generator:
         | 
| 221 | 
            +
                def __init__(self, config: DictConfig):
         | 
| 222 | 
            +
                    self.config = config.copy()
         | 
| 223 | 
            +
                    OmegaConf.set_readonly(self.config, True)
         | 
| 224 | 
            +
                    self.logger = get_logger(self.__class__.__name__)
         | 
| 225 | 
            +
                    
         | 
| 226 | 
            +
                    # init_torch(cudnn_benchmark=False)
         | 
| 227 | 
            +
                    self.configure_models()
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                def entrypoint(self):
         | 
| 230 | 
            +
                    
         | 
| 231 | 
            +
                    self.inference_loop()
         | 
| 232 | 
            +
                
         | 
| 233 | 
            +
                def get_fsdp_sharding_config(self, sharding_strategy, device_mesh_config):
         | 
| 234 | 
            +
                    device_mesh = None
         | 
| 235 | 
            +
                    fsdp_strategy = ShardingStrategy[sharding_strategy]
         | 
| 236 | 
            +
                    if (
         | 
| 237 | 
            +
                        fsdp_strategy in [ShardingStrategy._HYBRID_SHARD_ZERO2, ShardingStrategy.HYBRID_SHARD]
         | 
| 238 | 
            +
                        and device_mesh_config is not None
         | 
| 239 | 
            +
                    ):
         | 
| 240 | 
            +
                        device_mesh = init_device_mesh("cuda", tuple(device_mesh_config))
         | 
| 241 | 
            +
                    return device_mesh, fsdp_strategy
         | 
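For example, on eight GPUs a hybrid setup could shard within four-way groups and replicate across two of them (values illustrative; `gen` stands in for a Generator instance):

mesh, strategy = gen.get_fsdp_sharding_config("HYBRID_SHARD", [2, 4])
# mesh: 2 replica groups x 4-way sharding; strategy: ShardingStrategy.HYBRID_SHARD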
| 242 | 
            +
             | 
| 243 | 
            +
                    
         | 
| 244 | 
            +
                def configure_models(self):
         | 
| 245 | 
            +
                    self.configure_dit_model(device="cuda")
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                    self.dit.eval().to("cuda")
         | 
| 248 | 
            +
                    convert_linears_to_te_fp8(self.dit)
         | 
| 249 | 
            +
             | 
| 250 | 
            +
        self.dit = torch.compile(self.dit)
         | 
| 251 | 
            +
             | 
| 252 | 
            +
             | 
| 253 | 
            +
                    self.configure_vae_model(device="cuda")
         | 
| 254 | 
            +
                    if self.config.generation.get('extract_audio_feat', False):
         | 
| 255 | 
            +
                        self.configure_wav2vec(device="cpu")
         | 
| 256 | 
            +
                    self.configure_text_model(device="cuda")
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                    # # Initialize fsdp.
         | 
| 259 | 
            +
                    # self.configure_dit_fsdp_model()
         | 
| 260 | 
            +
                    # self.configure_text_fsdp_model()
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                    # quantize_(self.text_encoder, Int8WeightOnlyConfig())
         | 
| 263 | 
            +
                    # quantize_(self.dit, Float8DynamicActivationFloat8WeightConfig())
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                
         | 
| 266 | 
            +
                def configure_dit_model(self, device=get_device()):
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                    init_unified_parallel(self.config.dit.sp_size)
         | 
| 269 | 
            +
                    self.sp_size = get_unified_parallel_world_size()
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                    # Create DiT model on meta, then mark dtype as bfloat16 (no real allocation yet).
         | 
| 272 | 
            +
                    init_device = "meta"
         | 
| 273 | 
            +
                    with torch.device(init_device):
         | 
| 274 | 
            +
                        self.dit = create_object(self.config.dit.model)
         | 
| 275 | 
            +
                        self.dit = self.dit.to(dtype=torch.bfloat16)  # or: self.dit.bfloat16()
         | 
| 276 | 
            +
                    self.logger.info(f"Load DiT model on {init_device}.")
         | 
| 277 | 
            +
                    self.dit.eval().requires_grad_(False)
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                    # Load dit checkpoint.
         | 
| 280 | 
            +
                    path = self.config.dit.checkpoint_dir
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                    def _cast_state_dict_to_bf16(state):
         | 
| 283 | 
            +
                        for k, v in state.items():
         | 
| 284 | 
            +
                            if isinstance(v, torch.Tensor) and v.is_floating_point():
         | 
| 285 | 
            +
                                state[k] = v.to(dtype=torch.bfloat16, copy=False)
         | 
| 286 | 
            +
                        return state
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                    if path.endswith(".pth"):
         | 
| 289 | 
            +
                        # Load to CPU first; we’ll move the model later.
         | 
| 290 | 
            +
                        state = torch.load(path, map_location="cpu", mmap=True)
         | 
| 291 | 
            +
                        state = _cast_state_dict_to_bf16(state)
         | 
| 292 | 
            +
                        missing_keys, unexpected_keys = self.dit.load_state_dict(state, strict=False, assign=True)
         | 
| 293 | 
            +
                        self.logger.info(
         | 
| 294 | 
            +
                            f"dit loaded from {path}. Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}"
         | 
| 295 | 
            +
                        )
         | 
| 296 | 
            +
                    else:
         | 
| 297 | 
            +
                        from safetensors.torch import load_file
         | 
| 298 | 
            +
                        import json
         | 
| 299 | 
            +
                        def load_custom_sharded_weights(model_dir, base_name):
         | 
| 300 | 
            +
                            index_path = f"{model_dir}/{base_name}.safetensors.index.json"
         | 
| 301 | 
            +
                            with open(index_path, "r") as f:
         | 
| 302 | 
            +
                                index = json.load(f)
         | 
| 303 | 
            +
                            weight_map = index["weight_map"]
         | 
| 304 | 
            +
                            shard_files = set(weight_map.values())
         | 
| 305 | 
            +
                            state_dict = {}
         | 
| 306 | 
            +
                            for shard_file in shard_files:
         | 
| 307 | 
            +
                                shard_path = f"{model_dir}/{shard_file}"
         | 
| 308 | 
            +
                                # Load on CPU, then cast to bf16; we’ll move the whole module later.
         | 
| 309 | 
            +
                                shard_state = load_file(shard_path, device="cpu")
         | 
| 310 | 
            +
                                shard_state = {k: (v.to(dtype=torch.bfloat16, copy=False) if v.is_floating_point() else v)
         | 
| 311 | 
            +
                                            for k, v in shard_state.items()}
         | 
| 312 | 
            +
                                state_dict.update(shard_state)
         | 
| 313 | 
            +
                            return state_dict
         | 
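            # For reference, the index file read above follows the standard
            # safetensors sharding layout (entries illustrative):
            #   humo.safetensors.index.json
            #   {"weight_map": {"blocks.0.self_attn.q.weight": "humo-00001-of-00005.safetensors", ...}}
            # Each shard named in weight_map is loaded once and merged into a single state dict.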
| 314 | 
            +
             | 
| 315 | 
            +
                        state = load_custom_sharded_weights(path, 'humo')
         | 
| 316 | 
            +
                        self.dit.load_state_dict(state, strict=False, assign=True)
         | 
| 317 | 
            +
             | 
| 318 | 
            +
                    self.dit = meta_non_persistent_buffer_init_fn(self.dit)
         | 
| 319 | 
            +
             | 
| 320 | 
            +
                    target_device = get_device() if device in [get_device(), "cuda"] else device
         | 
| 321 | 
            +
                    self.dit.to(target_device)  # dtype already bf16
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                    # Print model size.
         | 
| 324 | 
            +
                    params = sum(p.numel() for p in self.dit.parameters())
         | 
| 325 | 
            +
                    self.logger.info(
         | 
| 326 | 
            +
                        f"[RANK:{get_global_rank()}] DiT Parameters: {clever_format(params, '%.3f')}"
         | 
| 327 | 
            +
                    )
         | 
| 328 | 
            +
             | 
| 329 | 
            +
                    
         | 
| 330 | 
            +
                def configure_vae_model(self, device=get_device()):
         | 
| 331 | 
            +
                    self.vae_stride = self.config.vae.vae_stride
         | 
| 332 | 
            +
                    self.vae = WanVAE(
         | 
| 333 | 
            +
                        vae_pth=self.config.vae.checkpoint,
         | 
| 334 | 
            +
                        device=device)
         | 
| 335 | 
            +
                    
         | 
| 336 | 
            +
                    if self.config.generation.height == 480:
         | 
| 337 | 
            +
                        self.zero_vae = torch.load(self.config.dit.zero_vae_path)
         | 
| 338 | 
            +
                    elif self.config.generation.height == 720:
         | 
| 339 | 
            +
                        self.zero_vae = torch.load(self.config.dit.zero_vae_720p_path)
         | 
| 340 | 
            +
                    else:
         | 
| 341 | 
            +
                        raise ValueError(f"Unsupported height {self.config.generation.height} for zero-vae.")
         | 
| 342 | 
            +
                
         | 
| 343 | 
            +
                def configure_wav2vec(self, device=get_device()):
         | 
| 344 | 
            +
                    audio_separator_model_file = self.config.audio.vocal_separator
         | 
| 345 | 
            +
                    wav2vec_model_path = self.config.audio.wav2vec_model
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                    self.audio_processor = AudioProcessor(
         | 
| 348 | 
            +
                        16000,
         | 
| 349 | 
            +
                        25,
         | 
| 350 | 
            +
                        wav2vec_model_path,
         | 
| 351 | 
            +
                        "all",
         | 
| 352 | 
            +
                        audio_separator_model_file,
         | 
| 353 | 
            +
            None,  # do not separate vocals
         | 
| 354 | 
            +
                        os.path.join(self.config.generation.output.dir, "vocals"),
         | 
| 355 | 
            +
                        device=device,
         | 
| 356 | 
            +
                    )
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                def configure_text_model(self, device=get_device()):
         | 
| 359 | 
            +
                    self.text_encoder = T5EncoderModel(
         | 
| 360 | 
            +
                        text_len=self.config.dit.model.text_len,
         | 
| 361 | 
            +
                        dtype=torch.bfloat16,
         | 
| 362 | 
            +
                        device=device,
         | 
| 363 | 
            +
                        checkpoint_path=self.config.text.t5_checkpoint,
         | 
| 364 | 
            +
                        tokenizer_path=self.config.text.t5_tokenizer,
         | 
| 365 | 
            +
                        )
         | 
| 366 | 
            +
             | 
| 367 | 
            +
                
         | 
| 368 | 
            +
                def configure_dit_fsdp_model(self):
         | 
| 369 | 
            +
                    from humo.models.wan_modules.model_humo import WanAttentionBlock
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                    dit_blocks = (WanAttentionBlock,)
         | 
| 372 | 
            +
             | 
| 373 | 
            +
                    # Init model_shard_cpu_group for saving checkpoint with sharded state_dict.
         | 
| 374 | 
            +
                    init_model_shard_cpu_group(
         | 
| 375 | 
            +
                        self.config.dit.fsdp.sharding_strategy,
         | 
| 376 | 
            +
                        self.config.dit.fsdp.get("device_mesh", None),
         | 
| 377 | 
            +
                    )
         | 
| 378 | 
            +
             | 
| 379 | 
            +
                    # Assert that dit has wrappable blocks.
         | 
| 380 | 
            +
                    assert any(isinstance(m, dit_blocks) for m in self.dit.modules())
         | 
| 381 | 
            +
             | 
| 382 | 
            +
                    # Define wrap policy on all dit blocks.
         | 
| 383 | 
            +
                    def custom_auto_wrap_policy(module, recurse, *args, **kwargs):
         | 
| 384 | 
            +
                        return recurse or isinstance(module, dit_blocks)
         | 
| 385 | 
            +
             | 
| 386 | 
            +
                    # Configure FSDP settings.
         | 
| 387 | 
            +
                    device_mesh, fsdp_strategy = self.get_fsdp_sharding_config(
         | 
| 388 | 
            +
                        self.config.dit.fsdp.sharding_strategy,
         | 
| 389 | 
            +
                        self.config.dit.fsdp.get("device_mesh", None),
         | 
| 390 | 
            +
                    )
         | 
| 391 | 
            +
                    settings = dict(
         | 
| 392 | 
            +
                        auto_wrap_policy=custom_auto_wrap_policy,
         | 
| 393 | 
            +
                        sharding_strategy=fsdp_strategy,
         | 
| 394 | 
            +
                        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
         | 
| 395 | 
            +
                        device_id=get_local_rank(),
         | 
| 396 | 
            +
                        use_orig_params=False,
         | 
| 397 | 
            +
                        sync_module_states=True,
         | 
| 398 | 
            +
                        forward_prefetch=True,
         | 
| 399 | 
            +
                        limit_all_gathers=False,  # False for ZERO2.
         | 
| 400 | 
            +
                        mixed_precision=MixedPrecision(
         | 
| 401 | 
            +
                            param_dtype=torch.bfloat16,
         | 
| 402 | 
            +
                            reduce_dtype=torch.float32,
         | 
| 403 | 
            +
                            buffer_dtype=torch.float32,
         | 
| 404 | 
            +
                        ),
         | 
| 405 | 
            +
                        device_mesh=device_mesh,
         | 
| 406 | 
            +
                        param_init_fn=meta_param_init_fn,
         | 
| 407 | 
            +
                    )
         | 
| 408 | 
            +
             | 
| 409 | 
            +
                    # Apply FSDP.
         | 
| 410 | 
            +
                    self.dit = FullyShardedDataParallel(self.dit, **settings)
         | 
| 411 | 
            +
                    # self.dit.to(get_device())
         | 
| 412 | 
            +
             | 
| 413 | 
            +
             | 
| 414 | 
            +
                def configure_text_fsdp_model(self):
         | 
| 415 | 
            +
                    # If FSDP is not enabled, put text_encoder to GPU and return.
         | 
| 416 | 
            +
                    if not self.config.text.fsdp.enabled:
         | 
| 417 | 
            +
                        self.text_encoder.to(get_device())
         | 
| 418 | 
            +
                        return
         | 
| 419 | 
            +
             | 
| 420 | 
            +
                    # from transformers.models.t5.modeling_t5 import T5Block
         | 
| 421 | 
            +
                    from humo.models.wan_modules.t5 import T5SelfAttention
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                    text_blocks = (torch.nn.Embedding, T5SelfAttention)
         | 
| 424 | 
            +
                    # text_blocks_names = ("QWenBlock", "QWenModel")  # QWen cannot be imported. Use str.
         | 
| 425 | 
            +
             | 
| 426 | 
            +
                    def custom_auto_wrap_policy(module, recurse, *args, **kwargs):
         | 
| 427 | 
            +
                        return (
         | 
| 428 | 
            +
                            recurse
         | 
| 429 | 
            +
                            or isinstance(module, text_blocks)
         | 
| 430 | 
            +
                        )
         | 
| 431 | 
            +
             | 
| 432 | 
            +
                    # Apply FSDP.
         | 
| 433 | 
            +
                    text_encoder_dtype = getattr(torch, self.config.text.dtype)
         | 
| 434 | 
            +
                    device_mesh, fsdp_strategy = self.get_fsdp_sharding_config(
         | 
| 435 | 
            +
                        self.config.text.fsdp.sharding_strategy,
         | 
| 436 | 
            +
                        self.config.text.fsdp.get("device_mesh", None),
         | 
| 437 | 
            +
                    )
         | 
| 438 | 
            +
                    self.text_encoder = FullyShardedDataParallel(
         | 
| 439 | 
            +
                        module=self.text_encoder,
         | 
| 440 | 
            +
                        auto_wrap_policy=custom_auto_wrap_policy,
         | 
| 441 | 
            +
                        sharding_strategy=fsdp_strategy,
         | 
| 442 | 
            +
                        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
         | 
| 443 | 
            +
                        device_id=get_local_rank(),
         | 
| 444 | 
            +
                        use_orig_params=False,
         | 
| 445 | 
            +
                        sync_module_states=False,
         | 
| 446 | 
            +
                        forward_prefetch=True,
         | 
| 447 | 
            +
                        limit_all_gathers=True,
         | 
| 448 | 
            +
                        mixed_precision=MixedPrecision(
         | 
| 449 | 
            +
                            param_dtype=text_encoder_dtype,
         | 
| 450 | 
            +
                            reduce_dtype=text_encoder_dtype,
         | 
| 451 | 
            +
                            buffer_dtype=text_encoder_dtype,
         | 
| 452 | 
            +
                        ),
         | 
| 453 | 
            +
                        device_mesh=device_mesh,
         | 
| 454 | 
            +
                    )
         | 
| 455 | 
            +
                    self.text_encoder.to(get_device()).requires_grad_(False)
         | 
| 456 | 
            +
             | 
| 457 | 
            +
             | 
| 458 | 
            +
                def load_image_latent_ref_id(self, path: str, size, device):
         | 
| 459 | 
            +
                    # Load size.
         | 
| 460 | 
            +
                    h, w = size[1], size[0]
         | 
| 461 | 
            +
             | 
| 462 | 
            +
                    # Load image.
         | 
| 463 | 
            +
        if not isinstance(path, str) and len(path) > 1:  # a list of reference images
         | 
| 464 | 
            +
                        ref_vae_latents = []
         | 
| 465 | 
            +
                        for image_path in path:
         | 
| 466 | 
            +
                            with Image.open(image_path) as img:
         | 
| 467 | 
            +
                                img = img.convert("RGB")
         | 
| 468 | 
            +
             | 
| 469 | 
            +
                                # Calculate the required size to keep aspect ratio and fill the rest with padding.
         | 
| 470 | 
            +
                                img_ratio = img.width / img.height
         | 
| 471 | 
            +
                                target_ratio = w / h
         | 
| 472 | 
            +
                                
         | 
| 473 | 
            +
                                if img_ratio > target_ratio:  # Image is wider than target
         | 
| 474 | 
            +
                                    new_width = w
         | 
| 475 | 
            +
                                    new_height = int(new_width / img_ratio)
         | 
| 476 | 
            +
                                else:  # Image is taller than target
         | 
| 477 | 
            +
                                    new_height = h
         | 
| 478 | 
            +
                                    new_width = int(new_height * img_ratio)
         | 
| 479 | 
            +
                                
         | 
| 480 | 
            +
                                # img = img.resize((new_width, new_height), Image.ANTIALIAS)
         | 
| 481 | 
            +
                                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
         | 
| 482 | 
            +
             | 
| 483 | 
            +
                                # Create a new image with the target size and place the resized image in the center
         | 
| 484 | 
            +
                                delta_w = w - img.size[0]
         | 
| 485 | 
            +
                                delta_h = h - img.size[1]
         | 
| 486 | 
            +
                                padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
         | 
| 487 | 
            +
                                new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
         | 
| 488 | 
            +
             | 
| 489 | 
            +
                                # Transform to tensor and normalize.
         | 
| 490 | 
            +
                                transform = Compose(
         | 
| 491 | 
            +
                                    [
         | 
| 492 | 
            +
                                        ToTensor(),
         | 
| 493 | 
            +
                                        Normalize(0.5, 0.5),
         | 
| 494 | 
            +
                                    ]
         | 
| 495 | 
            +
                                )
         | 
| 496 | 
            +
                                new_img = transform(new_img)
         | 
| 497 | 
            +
                                # img_vae_latent = self.vae_encode([new_img.unsqueeze(1)])[0]
         | 
| 498 | 
            +
                                img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device)
         | 
| 499 | 
            +
                                ref_vae_latents.append(img_vae_latent[0])
         | 
| 500 | 
            +
             | 
| 501 | 
            +
                        return [torch.cat(ref_vae_latents, dim=1)]
         | 
| 502 | 
            +
                    else:
         | 
| 503 | 
            +
                        if not isinstance(path, str):
         | 
| 504 | 
            +
                            path = path[0]
         | 
| 505 | 
            +
                        with Image.open(path) as img:
         | 
| 506 | 
            +
                            img = img.convert("RGB")
         | 
| 507 | 
            +
             | 
| 508 | 
            +
                            # Calculate the required size to keep aspect ratio and fill the rest with padding.
         | 
| 509 | 
            +
                            img_ratio = img.width / img.height
         | 
| 510 | 
            +
                            target_ratio = w / h
         | 
| 511 | 
            +
                            
         | 
| 512 | 
            +
                            if img_ratio > target_ratio:  # Image is wider than target
         | 
| 513 | 
            +
                                new_width = w
         | 
| 514 | 
            +
                                new_height = int(new_width / img_ratio)
         | 
| 515 | 
            +
                            else:  # Image is taller than target
         | 
| 516 | 
            +
                                new_height = h
         | 
| 517 | 
            +
                                new_width = int(new_height * img_ratio)
         | 
| 518 | 
            +
                            
         | 
| 519 | 
            +
                            # img = img.resize((new_width, new_height), Image.ANTIALIAS)
         | 
| 520 | 
            +
                            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
         | 
| 521 | 
            +
             | 
| 522 | 
            +
                            # Create a new image with the target size and place the resized image in the center
         | 
| 523 | 
            +
                            delta_w = w - img.size[0]
         | 
| 524 | 
            +
                            delta_h = h - img.size[1]
         | 
| 525 | 
            +
                            padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
         | 
| 526 | 
            +
                            new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
         | 
| 527 | 
            +
             | 
| 528 | 
            +
                            # Transform to tensor and normalize.
         | 
| 529 | 
            +
                            transform = Compose(
         | 
| 530 | 
            +
                                [
         | 
| 531 | 
            +
                                    ToTensor(),
         | 
| 532 | 
            +
                                    Normalize(0.5, 0.5),
         | 
| 533 | 
            +
                                ]
         | 
| 534 | 
            +
                            )
         | 
| 535 | 
            +
                            new_img = transform(new_img)
         | 
| 536 | 
            +
                            img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device)
         | 
| 537 | 
            +
             | 
| 538 | 
            +
                        # Vae encode.
         | 
| 539 | 
            +
                        return img_vae_latent
         | 
| 540 | 
            +
                
         | 
| 541 | 
            +
                def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2):
         | 
| 542 | 
            +
                    zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
         | 
| 543 | 
            +
        zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
         | 
| 544 | 
            +
                    iter_ = 1 + (frame_num - 1) // 4
         | 
| 545 | 
            +
                    audio_emb_wind = []
         | 
| 546 | 
            +
                    for lt_i in range(iter_):
         | 
| 547 | 
            +
                        if lt_i == 0:
         | 
| 548 | 
            +
                            st = frame0_idx + lt_i - 2
         | 
| 549 | 
            +
                            ed = frame0_idx + lt_i + 3
         | 
| 550 | 
            +
                            wind_feat = torch.stack([
         | 
| 551 | 
            +
                                audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
         | 
| 552 | 
            +
                                for i in range(st, ed)
         | 
| 553 | 
            +
                            ], dim=0)
         | 
| 554 | 
            +
                            wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0)
         | 
| 555 | 
            +
                        else:
         | 
| 556 | 
            +
                            st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
         | 
| 557 | 
            +
                            ed = frame0_idx + 1 + 4 * lt_i + audio_shift
         | 
| 558 | 
            +
                            wind_feat = torch.stack([
         | 
| 559 | 
            +
                                audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
         | 
| 560 | 
            +
                                for i in range(st, ed)
         | 
| 561 | 
            +
                            ], dim=0)
         | 
| 562 | 
            +
                        audio_emb_wind.append(wind_feat)
         | 
| 563 | 
            +
                    audio_emb_wind = torch.stack(audio_emb_wind, dim=0)
         | 
| 564 | 
            +
             | 
| 565 | 
            +
                    return audio_emb_wind, ed - audio_shift
         | 
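Each latent step gets an eight-frame audio context, zero-padded at the clip edges. A shape sketch under assumed dimensions (`gen` stands in for a Generator instance):

audio_emb = torch.zeros(100, 5, 1280)  # (frames, streams, dim); values illustrative
wind, next_idx = gen.get_audio_emb_window(audio_emb, frame_num=81, frame0_idx=0)
# wind.shape == (21, 8, 5, 1280): 1 + (81 - 1) // 4 = 21 latent steps, each with
# an 8-frame audio window; next_idx == 81 is where the next chunk would start.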
| 566 | 
            +
                
         | 
| 567 | 
            +
                def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"):
         | 
| 568 | 
            +
                    if wav_enc_type == "wav2vec":
         | 
| 569 | 
            +
                        feat_merge = audio_emb
         | 
| 570 | 
            +
                    elif wav_enc_type == "whisper":
         | 
| 571 | 
            +
                        feat0 = linear_interpolation_fps(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25)
         | 
| 572 | 
            +
                        feat1 = linear_interpolation_fps(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25)
         | 
| 573 | 
            +
                        feat2 = linear_interpolation_fps(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25)
         | 
| 574 | 
            +
                        feat3 = linear_interpolation_fps(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25)
         | 
| 575 | 
            +
                        feat4 = linear_interpolation_fps(audio_emb[:, :, 32], 50, 25)
         | 
| 576 | 
            +
                        feat_merge = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0]
         | 
| 577 | 
            +
                    else:
         | 
| 578 | 
            +
                        raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}")
         | 
| 579 | 
            +
                    
         | 
| 580 | 
            +
                    return feat_merge
         | 
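On the whisper path, the 33 encoder states are pooled into five streams (four eight-layer means plus the final layer), each resampled from 50 Hz to 25 Hz so one feature row lines up with one video frame. A shape walk-through under an assumed input layout:

# audio_emb: (1, T, 33, 1280) whisper hidden states at 50 Hz (assumed layout).
# Mean-pooling layers 0-7, 8-15, 16-23, 24-31 plus layer 32 gives five
# (1, T, 1280) streams; linear_interpolation_fps halves T to 25 Hz, and the
# stack + [0] yields feat_merge of shape (T // 2, 5, 1280).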
| 581 | 
            +
                
         | 
| 582 | 
            +
                def parse_output(self, output):
         | 
| 583 | 
            +
                    latent = output[0]
         | 
| 584 | 
            +
                    mask = None
         | 
| 585 | 
            +
                    return latent, mask
         | 
| 586 | 
            +
                
         | 
| 587 | 
            +
                def forward_tia(self, latents, timestep, t, step_change, arg_tia, arg_ti, arg_i, arg_null):
         | 
| 588 | 
            +
                    pos_tia, _ = self.parse_output(self.dit(
         | 
| 589 | 
            +
                        latents, t=timestep, **arg_tia
         | 
| 590 | 
            +
                        ))
         | 
| 591 | 
            +
                    torch.cuda.empty_cache()
         | 
| 592 | 
            +
             | 
| 593 | 
            +
                    pos_ti, _ = self.parse_output(self.dit(
         | 
| 594 | 
            +
                        latents, t=timestep, **arg_ti
         | 
| 595 | 
            +
                        ))
         | 
| 596 | 
            +
                    torch.cuda.empty_cache()
         | 
| 597 | 
            +
             | 
| 598 | 
            +
                    if t > step_change:
         | 
| 599 | 
            +
                        neg, _ = self.parse_output(self.dit(
         | 
| 600 | 
            +
                            latents, t=timestep, **arg_i
         | 
| 601 | 
            +
            ))  # img included in null, same as official Wan-2.1
         | 
| 602 | 
            +
                        torch.cuda.empty_cache()
         | 
| 603 | 
            +
             | 
| 604 | 
            +
                        noise_pred = self.config.generation.scale_a * (pos_tia - pos_ti) + \
         | 
| 605 | 
            +
                                self.config.generation.scale_t * (pos_ti - neg) + \
         | 
| 606 | 
            +
                                neg
         | 
| 607 | 
            +
                    else:
         | 
| 608 | 
            +
                        neg, _ = self.parse_output(self.dit(
         | 
| 609 | 
            +
                            latents, t=timestep, **arg_null
         | 
| 610 | 
            +
                            ))  # img not included in null
         | 
| 611 | 
            +
                        torch.cuda.empty_cache()
         | 
| 612 | 
            +
             | 
| 613 | 
            +
                        noise_pred = self.config.generation.scale_a * (pos_tia - pos_ti) + \
         | 
| 614 | 
            +
                                (self.config.generation.scale_t - 2.0) * (pos_ti - neg) + \
         | 
| 615 | 
            +
                                neg
         | 
| 616 | 
            +
                    return noise_pred
         | 
| 617 | 
            +
                
         | 
| 618 | 
            +
                def forward_ti(self, latents, timestep, t, step_change, arg_ti, arg_t, arg_i, arg_null):
         | 
| 619 | 
            +
                    # Positive with text+image (no audio)
         | 
| 620 | 
            +
                    pos_ti, _ = self.parse_output(self.dit(
         | 
| 621 | 
            +
                        latents, t=timestep, **arg_ti
         | 
| 622 | 
            +
                    ))
         | 
| 623 | 
            +
                    torch.cuda.empty_cache()
         | 
| 624 | 
            +
             | 
| 625 | 
            +
                    # Positive with text only (no image, no audio)
         | 
| 626 | 
            +
                    pos_t, _ = self.parse_output(self.dit(
         | 
| 627 | 
            +
                        latents, t=timestep, **arg_t
         | 
| 628 | 
            +
                    ))
         | 
| 629 | 
            +
                    torch.cuda.empty_cache()
         | 
| 630 | 
            +
             | 
| 631 | 
            +
        # Negative branch: while t > step_change (early steps) the image is included in the null branch, as in Wan-2.1; later steps drop it.
         | 
| 632 | 
            +
                    if t > step_change:
         | 
| 633 | 
            +
                        neg, _ = self.parse_output(self.dit(
         | 
| 634 | 
            +
                            latents, t=timestep, **arg_i
         | 
| 635 | 
            +
                        ))  # img included in null
         | 
| 636 | 
            +
                    else:
         | 
| 637 | 
            +
                        neg, _ = self.parse_output(self.dit(
         | 
| 638 | 
            +
                            latents, t=timestep, **arg_null
         | 
| 639 | 
            +
                        ))  # img NOT included in null
         | 
| 640 | 
            +
                    torch.cuda.empty_cache()
         | 
| 641 | 
            +
             | 
| 642 | 
            +
                    # Guidance blend: replace "scale_a" below with "scale_i" if you add a separate image scale in config
         | 
| 643 | 
            +
                    noise_pred = self.config.generation.scale_a * (pos_ti - pos_t) + \
         | 
| 644 | 
            +
                                self.config.generation.scale_t * (pos_t - neg) + \
         | 
| 645 | 
            +
                                neg
         | 
| 646 | 
            +
                    return noise_pred
         | 
| 647 | 
            +
             | 
| 648 | 
            +
                def forward_ta(self, latents, timestep, arg_ta, arg_t, arg_null):
         | 
| 649 | 
            +
                    pos_ta, _ = self.parse_output(self.dit(
         | 
| 650 | 
            +
                        latents, t=timestep, **arg_ta
         | 
| 651 | 
            +
                        ))
         | 
| 652 | 
            +
                    torch.cuda.empty_cache()
         | 
| 653 | 
            +
             | 
| 654 | 
            +
                    pos_t, _ = self.parse_output(self.dit(
         | 
| 655 | 
            +
                        latents, t=timestep, **arg_t
         | 
| 656 | 
            +
                        ))
         | 
| 657 | 
            +
                    torch.cuda.empty_cache()
         | 
| 658 | 
            +
             | 
| 659 | 
            +
                    neg, _ = self.parse_output(self.dit(
         | 
| 660 | 
            +
                            latents, t=timestep, **arg_null
         | 
| 661 | 
            +
                            ))
         | 
| 662 | 
            +
                    torch.cuda.empty_cache()
         | 
| 663 | 
            +
                        
         | 
| 664 | 
            +
                    noise_pred = self.config.generation.scale_a * (pos_ta - pos_t) + \
         | 
| 665 | 
            +
                            self.config.generation.scale_t * (pos_t - neg) + \
         | 
| 666 | 
            +
                            neg
         | 
| 667 | 
            +
                    return noise_pred
         | 
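All three `forward_*` helpers share the same nested classifier-free guidance: one scale boosts the extra condition (audio or image) relative to the base positive branch, a second boosts text relative to the negative branch. A generic sketch of the blend (names illustrative):

def nested_cfg(pos_full, pos_base, neg, scale_extra, scale_t):
    # eps = neg + scale_t * (pos_base - neg) + scale_extra * (pos_full - pos_base)
    return scale_extra * (pos_full - pos_base) + scale_t * (pos_base - neg) + neg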
| 668 | 
            +
                                
         | 
| 669 | 
            +
                @torch.no_grad()
         | 
| 670 | 
            +
                def inference(self,
         | 
| 671 | 
            +
                             input_prompt,
         | 
| 672 | 
            +
                             img_path,
         | 
| 673 | 
            +
                             audio_path,
         | 
| 674 | 
            +
                             size=(1280, 720),
         | 
| 675 | 
            +
                             frame_num=81,
         | 
| 676 | 
            +
                             shift=5.0,
         | 
| 677 | 
            +
                             sample_solver='unipc',
         | 
| 678 | 
            +
                             inference_mode='TIA',
         | 
| 679 | 
            +
                             sampling_steps=50,
         | 
| 680 | 
            +
                             n_prompt="",
         | 
| 681 | 
            +
                             seed=-1,
         | 
| 682 | 
            +
                 tea_cache_l1_thresh=0.0,
         | 
| 683 | 
            +
                 device=get_device(),
         | 
| 684 | 
            +
                    ):
         | 
| 685 | 
            +
             | 
| 686 | 
            +
                    print("inference started")
         | 
| 687 | 
            +
             | 
| 688 | 
            +
                    # self.vae.model.to(device=device)
         | 
| 689 | 
            +
                    if img_path is not None:
         | 
| 690 | 
            +
                        latents_ref = self.load_image_latent_ref_id(img_path, size, device)
         | 
| 691 | 
            +
                    else:
         | 
| 692 | 
            +
                        latents_ref = [torch.zeros(16, 1, size[1]//8, size[0]//8).to(device)]
         | 
| 693 | 
            +
                        
         | 
| 694 | 
            +
                    # self.vae.model.to(device="cpu")
         | 
| 695 | 
            +
             | 
| 696 | 
            +
                    print("vae finished")
         | 
| 697 | 
            +
             | 
| 698 | 
            +
                    latents_ref_neg = [torch.zeros_like(latent_ref) for latent_ref in latents_ref]
         | 
| 699 | 
            +
                    
         | 
| 700 | 
            +
                    # audio
         | 
| 701 | 
            +
                    if audio_path is not None:
         | 
| 702 | 
            +
                        if self.config.generation.extract_audio_feat:
         | 
| 703 | 
            +
                            self.audio_processor.whisper.to(device=device)
         | 
| 704 | 
            +
                            audio_emb, audio_length = self.audio_processor.preprocess(audio_path)
         | 
| 705 | 
            +
                            self.audio_processor.whisper.to(device='cpu')
         | 
| 706 | 
            +
                        else:
         | 
| 707 | 
            +
                            audio_emb_path = audio_path.replace(".wav", ".pt")
         | 
| 708 | 
            +
                            audio_emb = torch.load(audio_emb_path).to(device=device)
         | 
| 709 | 
            +
                            audio_emb = self.audio_emb_enc(audio_emb, wav_enc_type="whisper")
         | 
| 710 | 
            +
                            self.logger.info("使用预先提取好的音频特征: %s", audio_emb_path)
         | 
| 711 | 
            +
                    else:
         | 
| 712 | 
            +
                        audio_emb = torch.zeros(frame_num, 5, 1280).to(device)
         | 
| 713 | 
            +
                        
         | 
| 714 | 
            +
        frame_num = frame_num if frame_num != -1 else audio_length  # NOTE: audio_length is only set when features are extracted above
         | 
| 715 | 
            +
                    frame_num = 4 * ((frame_num - 1) // 4) + 1
         | 
| 716 | 
            +
                    audio_emb, _ = self.get_audio_emb_window(audio_emb, frame_num, frame0_idx=0)
         | 
| 717 | 
            +
                    zero_audio_pad = torch.zeros(latents_ref[0].shape[1], *audio_emb.shape[1:]).to(audio_emb.device)
         | 
| 718 | 
            +
                    audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0)
         | 
| 719 | 
            +
                    audio_emb = [audio_emb.to(device)]
         | 
| 720 | 
            +
                    audio_emb_neg = [torch.zeros_like(audio_emb[0])]
         | 
| 721 | 
            +
                    
         | 
| 722 | 
            +
                    # preprocess
         | 
| 723 | 
            +
                    self.patch_size = self.config.dit.model.patch_size
         | 
| 724 | 
            +
                    F = frame_num
         | 
| 725 | 
            +
                    target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + latents_ref[0].shape[1],
         | 
| 726 | 
            +
                                    size[1] // self.vae_stride[1],
         | 
| 727 | 
            +
                                    size[0] // self.vae_stride[2])
         | 
| 728 | 
            +
             | 
| 729 | 
            +
                    seq_len = math.ceil((target_shape[2] * target_shape[3]) /
         | 
| 730 | 
            +
                                        (self.patch_size[1] * self.patch_size[2]) *
         | 
| 731 | 
            +
                                        target_shape[1] / self.sp_size) * self.sp_size
         | 
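        # Worked example of the token count (assumed values): size=(832, 480),
        # frame_num=81, vae_stride=(4, 8, 8), patch_size=(1, 2, 2), one reference
        # latent frame, sp_size=1:
        #   target_shape = (16, (81-1)//4 + 1 + 1, 480//8, 832//8) = (16, 22, 60, 104)
        #   seq_len = ceil(60 * 104 / (2 * 2) * 22 / 1) * 1 = 34320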
| 732 | 
            +
             | 
| 733 | 
            +
                    if n_prompt == "":
         | 
| 734 | 
            +
                        n_prompt = self.config.generation.sample_neg_prompt
         | 
| 735 | 
            +
                    seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
         | 
| 736 | 
            +
                    seed_g = torch.Generator(device=device)
         | 
| 737 | 
            +
                    seed_g.manual_seed(seed)
         | 
| 738 | 
            +
             | 
| 739 | 
            +
                    # self.text_encoder.model.to(device)
         | 
| 740 | 
            +
                    context = self.text_encoder([input_prompt], device)
         | 
| 741 | 
            +
                    context_null = self.text_encoder([n_prompt], device)
         | 
| 742 | 
            +
                    # self.text_encoder.model.cpu()
         | 
| 743 | 
            +
             | 
| 744 | 
            +
                    print("text encoder finished")
         | 
| 745 | 
            +
             | 
| 746 | 
            +
                    noise = [
         | 
| 747 | 
            +
                        torch.randn(
         | 
| 748 | 
            +
                            target_shape[0],
         | 
| 749 | 
            +
                            target_shape[1], # - latents_ref[0].shape[1],
         | 
| 750 | 
            +
                            target_shape[2],
         | 
| 751 | 
            +
                            target_shape[3],
         | 
| 752 | 
            +
                            dtype=torch.float32,
         | 
| 753 | 
            +
                            device=device,
         | 
| 754 | 
            +
                            generator=seed_g)
         | 
| 755 | 
            +
                    ]
         | 
| 756 | 
            +
             | 
| 757 | 
            +
                    @contextmanager
         | 
| 758 | 
            +
                    def noop_no_sync():
         | 
| 759 | 
            +
                        yield
         | 
| 760 | 
            +
             | 
| 761 | 
            +
                    no_sync = getattr(self.dit, 'no_sync', noop_no_sync)
         | 
| 762 | 
            +
                    step_change = self.config.generation.step_change # 980
         | 
| 763 | 
            +
             | 
| 764 | 
            +
                    # evaluation mode
         | 
| 765 | 
            +
                    with make_fp8_ctx(True), torch.autocast('cuda', dtype=torch.bfloat16), torch.no_grad(), no_sync():
         | 
| 766 | 
            +
             | 
| 767 | 
            +
                        if sample_solver == 'unipc':
         | 
| 768 | 
            +
                            sample_scheduler = FlowUniPCMultistepScheduler(
         | 
| 769 | 
            +
                                num_train_timesteps=1000,
         | 
| 770 | 
            +
                                shift=1,
         | 
| 771 | 
            +
                                use_dynamic_shifting=False)
         | 
| 772 | 
            +
                            sample_scheduler.set_timesteps(
         | 
| 773 | 
            +
                                sampling_steps, device=device, shift=shift)
         | 
| 774 | 
            +
                timesteps = sample_scheduler.timesteps
            else:
                raise NotImplementedError(f"Unsupported sample_solver: {sample_solver}")
         | 
| 775 | 
            +
             | 
| 776 | 
            +
                        # sample videos
         | 
| 777 | 
            +
                        latents = noise
         | 
| 778 | 
            +
             | 
| 779 | 
            +
                        msk = torch.ones(4, target_shape[1], target_shape[2], target_shape[3], device=get_device())
         | 
| 780 | 
            +
            msk[:, :-latents_ref[0].shape[1]] = 0
         | 
| 781 | 
            +
             | 
| 782 | 
            +
                        zero_vae = self.zero_vae[:, :(target_shape[1]-latents_ref[0].shape[1])].to(
         | 
| 783 | 
            +
                            device=get_device(), dtype=latents_ref[0].dtype)
         | 
| 784 | 
            +
                        y_c = torch.cat([
         | 
| 785 | 
            +
                            zero_vae,
         | 
| 786 | 
            +
                            latents_ref[0]
         | 
| 787 | 
            +
                            ], dim=1)
         | 
| 788 | 
            +
                        y_c = [torch.concat([msk, y_c])]
         | 
| 789 | 
            +
             | 
| 790 | 
            +
                        y_null = self.zero_vae[:, :target_shape[1]].to(
         | 
| 791 | 
            +
                            device=get_device(), dtype=latents_ref[0].dtype)
         | 
| 792 | 
            +
                        y_null = [torch.concat([msk, y_null])]
         | 
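            +
                    # Note: `msk` marks which latent frames carry real conditioning:
                    # 0 for the frames to be generated, 1 for the trailing reference
                    # latents. `y_c` pads the generated span with the precomputed
                    # zero-VAE latent and appends the reference latents; `y_null` is
                    # the fully unconditional variant used for guidance.
         | 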
| 793 | 
            +
             | 
| 794 | 
            +
                    # tea_cache_l1_thresh comes in via the function argument; None or a value <= 0 disables TeaCache below.
         | 
| 795 | 
            +
                        tea_cache_model_id = "Wan2.1-T2V-14B"
         | 
| 796 | 
            +
             | 
| 797 | 
            +
                        arg_null = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_null, 'context': context_null, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None}
         | 
| 798 | 
            +
                        arg_t = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_null, 'context': context, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None}
         | 
| 799 | 
            +
                        arg_i = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_c, 'context': context_null, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None}
         | 
| 800 | 
            +
                        arg_ti = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_c, 'context': context, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None}
         | 
| 801 | 
            +
                        arg_ta = {'seq_len': seq_len, 'audio': audio_emb, 'y': y_null, 'context': context, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None}
         | 
| 802 | 
            +
                        arg_tia = {'seq_len': seq_len, 'audio': audio_emb, 'y': y_c, 'context': context, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None}
         | 
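            +
                    # Note: the suffixes t/i/a denote text, reference-image, and audio
                    # conditioning. arg_null drops all three (negative audio, null
                    # image latents, null prompt); arg_tia enables all of them. Each
                    # branch gets its own TeaCache instance, since cached residuals
                    # are only valid within a single conditioning stream.
         | 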
| 803 | 
            +
                        
         | 
| 804 | 
            +
                        torch.cuda.empty_cache()
         | 
| 805 | 
            +
                        # self.dit.to(device=get_device())
         | 
| 806 | 
            +
                        for _, t in enumerate(tqdm(timesteps)):
         | 
| 807 | 
            +
                            timestep = [t]
         | 
| 808 | 
            +
                            timestep = torch.stack(timestep)
         | 
| 809 | 
            +
             | 
| 810 | 
            +
                            if inference_mode == "TIA":
         | 
| 811 | 
            +
                                noise_pred = self.forward_tia(latents, timestep, t, step_change, 
         | 
| 812 | 
            +
                                                              arg_tia, arg_ti, arg_i, arg_null)
         | 
| 813 | 
            +
                            elif inference_mode == "TA":
         | 
| 814 | 
            +
                                noise_pred = self.forward_ta(latents, timestep, arg_ta, arg_t, arg_null)
         | 
| 815 | 
            +
                            elif inference_mode == "TI":
         | 
| 816 | 
            +
                                noise_pred = self.forward_ti(latents, timestep, t, step_change,
         | 
| 817 | 
            +
                                                            arg_ti, arg_t, arg_i, arg_null)
         | 
| 818 | 
            +
                            else:
         | 
| 819 | 
            +
                                raise ValueError(f"Unsupported generation mode: {self.config.generation.mode}")
         | 
| 820 | 
            +
             | 
| 821 | 
            +
                            temp_x0 = sample_scheduler.step(
         | 
| 822 | 
            +
                                noise_pred.unsqueeze(0),
         | 
| 823 | 
            +
                                t,
         | 
| 824 | 
            +
                                latents[0].unsqueeze(0),
         | 
| 825 | 
            +
                                return_dict=False,
         | 
| 826 | 
            +
                                generator=seed_g)[0]
         | 
| 827 | 
            +
                            latents = [temp_x0.squeeze(0)]
         | 
| 828 | 
            +
             | 
| 829 | 
            +
                            del timestep
         | 
| 830 | 
            +
                            torch.cuda.empty_cache()
         | 
| 831 | 
            +
             | 
| 832 | 
            +
                        x0 = latents
         | 
| 833 | 
            +
                        x0 = [x0_[:,:-latents_ref[0].shape[1]] for x0_ in x0]
         | 
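            +
                        # Note: the reference latents appended along the time axis are
                        # conditioning only; they are sliced off here so the VAE
                        # decodes just the generated frames.
         | 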
| 834 | 
            +
             | 
| 835 | 
            +
                        # if offload_model:
         | 
| 836 | 
            +
                        # self.dit.cpu()
         | 
| 837 | 
            +
             | 
| 838 | 
            +
                        print("dit finished")
         | 
| 839 | 
            +
             | 
| 840 | 
            +
                        torch.cuda.empty_cache()
         | 
| 841 | 
            +
                        # if get_local_rank() == 0:
         | 
| 842 | 
            +
                        # self.vae.model.to(device=device)
         | 
| 843 | 
            +
                        videos = self.vae.decode(x0)
         | 
| 844 | 
            +
                        # self.vae.model.to(device="cpu")
         | 
| 845 | 
            +
             | 
| 846 | 
            +
                        print("vae 2 finished")
         | 
| 847 | 
            +
             | 
| 848 | 
            +
                    del noise, latents, noise_pred
         | 
| 849 | 
            +
                    del audio_emb, audio_emb_neg, latents_ref, latents_ref_neg, context, context_null
         | 
| 850 | 
            +
                    del x0, temp_x0
         | 
| 851 | 
            +
                    del sample_scheduler
         | 
| 852 | 
            +
                    torch.cuda.empty_cache()
         | 
| 853 | 
            +
                    gc.collect()
         | 
| 854 | 
            +
                    torch.cuda.synchronize()
         | 
| 855 | 
            +
                    if dist.is_initialized():
         | 
| 856 | 
            +
                        dist.barrier()
         | 
| 857 | 
            +
             | 
| 858 | 
            +
                    return videos[0] # if get_local_rank() == 0 else None
         | 
| 859 | 
            +
             | 
| 860 | 
            +
             | 
| 861 | 
            +
                def inference_loop(self, prompt, ref_img_path, audio_path, output_dir, filename, inference_mode = "TIA", width = 832, height = 480, steps=50, frames = 97, tea_cache_l1_thresh = 0.0, seed = 0):
         | 
| 862 | 
            +
             | 
| 863 | 
            +
                    video = self.inference(
         | 
| 864 | 
            +
                        prompt,
         | 
| 865 | 
            +
                        ref_img_path,
         | 
| 866 | 
            +
                        audio_path,
         | 
| 867 | 
            +
                        size=SIZE_CONFIGS[f"{width}*{height}"],
         | 
| 868 | 
            +
                        frame_num=frames,
         | 
| 869 | 
            +
                        shift=self.config.diffusion.timesteps.sampling.shift,
         | 
| 870 | 
            +
                        sample_solver='unipc',
         | 
| 871 | 
            +
                        sampling_steps=steps,
         | 
| 872 | 
            +
                        inference_mode = inference_mode,
         | 
| 873 | 
            +
                        tea_cache_l1_thresh = tea_cache_l1_thresh,
         | 
| 874 | 
            +
                        seed=seed
         | 
| 875 | 
            +
                    )
         | 
| 876 | 
            +
             | 
| 877 | 
            +
                    torch.cuda.empty_cache()
         | 
| 878 | 
            +
                    gc.collect()
         | 
| 879 | 
            +
                    
         | 
| 880 | 
            +
                    # Save samples.
         | 
| 881 | 
            +
                    if get_sequence_parallel_rank() == 0:
         | 
| 882 | 
            +
                        pathname = self.save_sample(
         | 
| 883 | 
            +
                            sample=video,
         | 
| 884 | 
            +
                            audio_path=audio_path,
         | 
| 885 | 
            +
                            output_dir = output_dir,
         | 
| 886 | 
            +
                            filename=filename,
         | 
| 887 | 
            +
                        )
         | 
| 888 | 
            +
                        self.logger.info(f"Finished {filename}, saved to {pathname}.")
         | 
| 889 | 
            +
                    
         | 
| 890 | 
            +
                    del video, prompt
         | 
| 891 | 
            +
                    torch.cuda.empty_cache()
         | 
| 892 | 
            +
                    gc.collect()
         | 
| 893 | 
            +
                        
         | 
| 894 | 
            +
             | 
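            +
                # Example call (illustrative values; the paths are hypothetical):
                #   gen = Generator(config)
                #   gen.inference_loop(
                #       prompt="a person talking to the camera",
                #       ref_img_path=["ref.png"],
                #       audio_path="speech.wav",
                #       output_dir="./output", filename="demo",
                #       inference_mode="TIA", width=832, height=480,
                #       steps=50, frames=97, seed=0)
         | 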
| 895 | 
            +
                def save_sample(self, *, sample: torch.Tensor, audio_path: str, output_dir: str, filename: str):
         | 
| 896 | 
            +
                    gen_config = self.config.generation
         | 
| 897 | 
            +
                    # Prepare file path.
         | 
| 898 | 
            +
                    extension = ".mp4" if sample.ndim == 4 else ".png"
         | 
| 899 | 
            +
                    filename += extension
         | 
| 900 | 
            +
                    pathname = os.path.join(output_dir, filename)
         | 
| 901 | 
            +
                    # Convert sample.
         | 
| 902 | 
            +
                    sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).to("cpu", torch.uint8)
         | 
| 903 | 
            +
                    sample = rearrange(sample, "c t h w -> t h w c")
         | 
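            +
                    # Note: the conversion maps [-1, 1] to [0, 255] (e.g. -1 -> 0,
                    # 0 -> 127, 1 -> 255) before casting to uint8 for encoding.
         | 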
| 904 | 
            +
                    # Save file.
         | 
| 905 | 
            +
                    if sample.ndim == 4:
         | 
| 906 | 
            +
                        if audio_path is not None:
         | 
| 907 | 
            +
                            tensor_to_video(
         | 
| 908 | 
            +
                                sample.numpy(),
         | 
| 909 | 
            +
                                pathname,
         | 
| 910 | 
            +
                                audio_path,
         | 
| 911 | 
            +
                                fps=gen_config.fps)
         | 
| 912 | 
            +
                        else:
         | 
| 913 | 
            +
                            mediapy.write_video(
         | 
| 914 | 
            +
                                path=pathname,
         | 
| 915 | 
            +
                                images=sample.numpy(),
         | 
| 916 | 
            +
                                fps=gen_config.fps,
         | 
| 917 | 
            +
                        )
         | 
| 918 | 
            +
                    else:
         | 
| 919 | 
            +
                        raise ValueError(f"Saving non-video samples is not supported (got shape {tuple(sample.shape)}).")
         | 
| 920 | 
            +
                    return pathname
         | 
| 921 | 
            +
                
         | 
| 922 | 
            +
             | 
| 923 | 
            +
                def prepare_positive_prompts(self):
         | 
| 924 | 
            +
                    pos_prompts = self.config.generation.positive_prompt
         | 
| 925 | 
            +
                    if pos_prompts.endswith(".json"):
         | 
| 926 | 
            +
                        pos_prompts = prepare_json_dataset(pos_prompts)
         | 
| 927 | 
            +
                    else:
         | 
| 928 | 
            +
                        raise NotImplementedError
         | 
| 929 | 
            +
                    assert isinstance(pos_prompts, ListConfig)
         | 
| 930 | 
            +
             | 
| 931 | 
            +
                    return pos_prompts
         | 
| 932 | 
            +
                
         | 
| 933 | 
            +
            class TeaCache:
         | 
| 934 | 
            +
                def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
         | 
| 935 | 
            +
                    self.num_inference_steps = num_inference_steps
         | 
| 936 | 
            +
                    self.step = 0
         | 
| 937 | 
            +
                    self.accumulated_rel_l1_distance = 0
         | 
| 938 | 
            +
                    self.previous_modulated_input = None
         | 
| 939 | 
            +
                    self.rel_l1_thresh = rel_l1_thresh
         | 
| 940 | 
            +
                    self.previous_residual = None
         | 
| 941 | 
            +
                    self.previous_hidden_states = None
         | 
| 942 | 
            +
                    
         | 
| 943 | 
            +
                    self.coefficients_dict = {
         | 
| 944 | 
            +
                        "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
         | 
| 945 | 
            +
                        "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
         | 
| 946 | 
            +
                        "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04,  1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
         | 
| 947 | 
            +
                        "Wan2.1-I2V-14B-720P": [ 8.10705460e+03,  2.13393892e+03, -3.72934672e+02,  1.66203073e+01, -4.17769401e-02],
         | 
| 948 | 
            +
                    }
         | 
| 949 | 
            +
                    if model_id not in self.coefficients_dict:
         | 
| 950 | 
            +
                        supported_model_ids = ", ".join(self.coefficients_dict)
         | 
| 951 | 
            +
                        raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
         | 
| 952 | 
            +
                    self.coefficients = self.coefficients_dict[model_id]
         | 
| 953 | 
            +
             | 
| 954 | 
            +
                def check(self, dit, x, t_mod):
         | 
| 955 | 
            +
                    modulated_inp = t_mod.clone()
         | 
| 956 | 
            +
                    if self.step == 0 or self.step == self.num_inference_steps - 1:
         | 
| 957 | 
            +
                        should_calc = True
         | 
| 958 | 
            +
                        self.accumulated_rel_l1_distance = 0
         | 
| 959 | 
            +
                    else:
         | 
| 960 | 
            +
                        coefficients = self.coefficients
         | 
| 961 | 
            +
                        rescale_func = np.poly1d(coefficients)
         | 
| 962 | 
            +
                        self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
         | 
| 963 | 
            +
                        if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
         | 
| 964 | 
            +
                            should_calc = False
         | 
| 965 | 
            +
                        else:
         | 
| 966 | 
            +
                            should_calc = True
         | 
| 967 | 
            +
                            self.accumulated_rel_l1_distance = 0
         | 
| 968 | 
            +
                    self.previous_modulated_input = modulated_inp
         | 
| 969 | 
            +
                    self.step += 1
         | 
| 970 | 
            +
                    if self.step == self.num_inference_steps:
         | 
| 971 | 
            +
                        self.step = 0
         | 
| 972 | 
            +
                    if should_calc:
         | 
| 973 | 
            +
                        self.previous_hidden_states = x.clone()
         | 
| 974 | 
            +
                    return not should_calc
         | 
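            +
                # Note: `check` accumulates the relative L1 change of the modulated
                # input, rescaled by a model-specific polynomial (np.poly1d over
                # `self.coefficients`), and returns True (skip the full forward)
                # while the total stays below `rel_l1_thresh`; the first and last
                # steps are always computed.
         | 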
| 975 | 
            +
             | 
| 976 | 
            +
                def store(self, hidden_states):
         | 
| 977 | 
            +
                    if self.previous_hidden_states is None:
         | 
| 978 | 
            +
                        return
         | 
| 979 | 
            +
                    self.previous_residual = hidden_states - self.previous_hidden_states
         | 
| 980 | 
            +
                    self.previous_hidden_states = None
         | 
| 981 | 
            +
             | 
| 982 | 
            +
                def update(self, hidden_states):
         | 
| 983 | 
            +
                    hidden_states = hidden_states + self.previous_residual
         | 
| 984 | 
            +
                    return hidden_states
         | 
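            +
            # Sketch of the intended call pattern inside the DiT forward (inferred
            # from the method names; the actual wiring lives in the model code and
            # `run_transformer_blocks` is a stand-in):
            #
            #   if tea_cache is not None and tea_cache.check(dit, x, t_mod):
            #       x = tea_cache.update(x)          # reuse the cached residual
            #   else:
            #       x = run_transformer_blocks(x)    # full forward
            #       if tea_cache is not None:
            #           tea_cache.store(x)           # cache the new residual
         | 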
    	
        humo/generate_1_7B.py
    ADDED
    
    | @@ -0,0 +1,622 @@ | |
| 1 | 
            +
            # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
         | 
| 2 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 3 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 4 | 
            +
            # You may obtain a copy of the License at
         | 
| 5 | 
            +
            # http://www.apache.org/licenses/LICENSE-2.0
         | 
| 6 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 7 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 8 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 9 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 10 | 
            +
            # limitations under the License.
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            # Inference codes adapted from [SeedVR]
         | 
| 13 | 
            +
            # https://github.com/ByteDance-Seed/SeedVR/blob/main/projects/inference_seedvr2_7b.py
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            import math
         | 
| 16 | 
            +
            import os
         | 
| 17 | 
            +
            import gc
         | 
| 18 | 
            +
            import random
         | 
| 19 | 
            +
            import sys
         | 
| 20 | 
            +
            import mediapy
         | 
| 21 | 
            +
            import torch
         | 
| 22 | 
            +
            import torch.distributed as dist
         | 
| 23 | 
            +
            from omegaconf import DictConfig, ListConfig, OmegaConf
         | 
| 24 | 
            +
            from einops import rearrange
         | 
| 25 | 
            +
            from omegaconf import OmegaConf
         | 
| 26 | 
            +
            from PIL import Image, ImageOps
         | 
| 27 | 
            +
            from torchvision.transforms import ToTensor
         | 
| 28 | 
            +
            from tqdm import tqdm
         | 
| 29 | 
            +
            from torch.distributed.device_mesh import init_device_mesh
         | 
| 30 | 
            +
            from torch.distributed.fsdp import (
         | 
| 31 | 
            +
                BackwardPrefetch,
         | 
| 32 | 
            +
                FullyShardedDataParallel,
         | 
| 33 | 
            +
                MixedPrecision,
         | 
| 34 | 
            +
                ShardingStrategy,
         | 
| 35 | 
            +
            )
         | 
| 36 | 
            +
            from common.distributed import (
         | 
| 37 | 
            +
                get_device,
         | 
| 38 | 
            +
                get_global_rank,
         | 
| 39 | 
            +
                get_local_rank,
         | 
| 40 | 
            +
                meta_param_init_fn,
         | 
| 41 | 
            +
                meta_non_persistent_buffer_init_fn,
         | 
| 42 | 
            +
                init_torch,
         | 
| 43 | 
            +
            )
         | 
| 44 | 
            +
            from common.distributed.advanced import (
         | 
| 45 | 
            +
                init_unified_parallel,
         | 
| 46 | 
            +
                get_unified_parallel_world_size,
         | 
| 47 | 
            +
                get_sequence_parallel_rank,
         | 
| 48 | 
            +
                init_model_shard_cpu_group,
         | 
| 49 | 
            +
            )
         | 
| 50 | 
            +
            from common.logger import get_logger
         | 
| 51 | 
            +
            from common.config import create_object
         | 
| 52 | 
            +
            from common.distributed import get_device, get_global_rank
         | 
| 53 | 
            +
            from torchvision.transforms import Compose, Normalize, ToTensor
         | 
| 54 | 
            +
            from humo.models.wan_modules.t5 import T5EncoderModel
         | 
| 55 | 
            +
            from humo.models.wan_modules.vae import WanVAE
         | 
| 56 | 
            +
            from humo.models.utils.utils import tensor_to_video, prepare_json_dataset
         | 
| 57 | 
            +
            from contextlib import contextmanager
         | 
| 58 | 
            +
            import torch.cuda.amp as amp
         | 
| 59 | 
            +
            from humo.models.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
         | 
| 60 | 
            +
            from humo.utils.audio_processor_whisper import AudioProcessor
         | 
| 61 | 
            +
            from humo.utils.wav2vec import linear_interpolation_fps
         | 
| 62 | 
            +
             | 
| 63 | 
            +
             | 
| 64 | 
            +
            image_transform = Compose([
         | 
| 65 | 
            +
                ToTensor(),
         | 
| 66 | 
            +
                Normalize(mean=0.5, std=0.5),
         | 
| 67 | 
            +
            ])
         | 
| 68 | 
            +
             | 
| 69 | 
            +
            SIZE_CONFIGS = {
         | 
| 70 | 
            +
                '720*1280': (720, 1280),
         | 
| 71 | 
            +
                '1280*720': (1280, 720),
         | 
| 72 | 
            +
                '480*832': (480, 832),
         | 
| 73 | 
            +
                '832*480': (832, 480),
         | 
| 74 | 
            +
                '1024*1024': (1024, 1024),
         | 
| 75 | 
            +
            }
         | 
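            +
            # Note: keys follow the "width*height" convention used by the
            # SIZE_CONFIGS[f"{width}*{height}"] lookup in inference_loop.
         | 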
| 76 | 
            +
             | 
| 77 | 
            +
            def clever_format(nums, format="%.2f"):
         | 
| 78 | 
            +
                from typing import Iterable
         | 
| 79 | 
            +
                if not isinstance(nums, Iterable):
         | 
| 80 | 
            +
                    nums = [nums]
         | 
| 81 | 
            +
                clever_nums = []
         | 
| 82 | 
            +
                for num in nums:
         | 
| 83 | 
            +
                    if num > 1e12:
         | 
| 84 | 
            +
                        clever_nums.append(format % (num / 1e12) + "T")
         | 
| 85 | 
            +
                    elif num > 1e9:
         | 
| 86 | 
            +
                        clever_nums.append(format % (num / 1e9) + "G")
         | 
| 87 | 
            +
                    elif num > 1e6:
         | 
| 88 | 
            +
                        clever_nums.append(format % (num / 1e6) + "M")
         | 
| 89 | 
            +
                    elif num > 1e3:
         | 
| 90 | 
            +
                        clever_nums.append(format % (num / 1e3) + "K")
         | 
| 91 | 
            +
                    else:
         | 
| 92 | 
            +
                        clever_nums.append(format % num + "B")
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                clever_nums = clever_nums[0] if len(clever_nums) == 1 else (*clever_nums,)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                return clever_nums
         | 
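            +
            # Example: clever_format(1.3e9) returns "1.30G", and
            # clever_format([1.3e9, 2.5e3]) returns ("1.30G", "2.50K").
         | 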
| 97 | 
            +
             | 
| 98 | 
            +
             | 
| 99 | 
            +
            class Generator():
         | 
| 100 | 
            +
                def __init__(self, config: DictConfig):
         | 
| 101 | 
            +
                    self.config = config.copy()
         | 
| 102 | 
            +
                    OmegaConf.set_readonly(self.config, True)
         | 
| 103 | 
            +
                    self.logger = get_logger(self.__class__.__name__)
         | 
| 104 | 
            +
                    self.configure_models()
         | 
| 105 | 
            +
                    
         | 
| 106 | 
            +
                    # init_torch(cudnn_benchmark=False)
         | 
| 107 | 
            +
                
         | 
| 108 | 
            +
                def get_fsdp_sharding_config(self, sharding_strategy, device_mesh_config):
         | 
| 109 | 
            +
                    device_mesh = None
         | 
| 110 | 
            +
                    fsdp_strategy = ShardingStrategy[sharding_strategy]
         | 
| 111 | 
            +
                    if (
         | 
| 112 | 
            +
                        fsdp_strategy in [ShardingStrategy._HYBRID_SHARD_ZERO2, ShardingStrategy.HYBRID_SHARD]
         | 
| 113 | 
            +
                        and device_mesh_config is not None
         | 
| 114 | 
            +
                    ):
         | 
| 115 | 
            +
                        device_mesh = init_device_mesh("cuda", tuple(device_mesh_config))
         | 
| 116 | 
            +
                    return device_mesh, fsdp_strategy
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                def configure_models(self):
         | 
| 119 | 
            +
                    self.configure_dit_model(device="cpu")
         | 
| 120 | 
            +
                    self.configure_vae_model()
         | 
| 121 | 
            +
                    if self.config.generation.get('extract_audio_feat', False):
         | 
| 122 | 
            +
                        self.configure_wav2vec(device="cpu")
         | 
| 123 | 
            +
                    self.configure_text_model(device="cpu")
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                    # Initialize fsdp.
         | 
| 126 | 
            +
                    self.configure_dit_fsdp_model()
         | 
| 127 | 
            +
                    self.configure_text_fsdp_model()
         | 
| 128 | 
            +
                
         | 
| 129 | 
            +
                def configure_dit_model(self, device=get_device()):
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                    init_unified_parallel(self.config.dit.sp_size)
         | 
| 132 | 
            +
                    self.sp_size = get_unified_parallel_world_size()
         | 
| 133 | 
            +
                    
         | 
| 134 | 
            +
                    # Create dit model.
         | 
| 135 | 
            +
                    init_device = "meta"
         | 
| 136 | 
            +
                    with torch.device(init_device):
         | 
| 137 | 
            +
                        self.dit = create_object(self.config.dit.model)
         | 
| 138 | 
            +
                    self.logger.info(f"Load DiT model on {init_device}.")
         | 
| 139 | 
            +
                    self.dit.eval().requires_grad_(False)
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                    # Load dit checkpoint.
         | 
| 142 | 
            +
                    path = self.config.dit.checkpoint_dir
         | 
| 143 | 
            +
                    if path.endswith(".pth"):
         | 
| 144 | 
            +
                        state = torch.load(path, map_location=device, mmap=True)
         | 
| 145 | 
            +
                        missing_keys, unexpected_keys = self.dit.load_state_dict(state, strict=False, assign=True)
         | 
| 146 | 
            +
                        self.logger.info(
         | 
| 147 | 
            +
                            f"dit loaded from {path}. "
         | 
| 148 | 
            +
                            f"Missing keys: {len(missing_keys)}, "
         | 
| 149 | 
            +
                            f"Unexpected keys: {len(unexpected_keys)}"
         | 
| 150 | 
            +
                        )
         | 
| 151 | 
            +
                    else:
         | 
| 152 | 
            +
                        from safetensors.torch import load_file
         | 
| 153 | 
            +
                        import json
         | 
| 154 | 
            +
                        def load_custom_sharded_weights(model_dir, base_name, device=device):
         | 
| 155 | 
            +
                            index_path = f"{model_dir}/{base_name}.safetensors.index.json"
         | 
| 156 | 
            +
                            with open(index_path, "r") as f:
         | 
| 157 | 
            +
                                index = json.load(f)
         | 
| 158 | 
            +
                            weight_map = index["weight_map"]
         | 
| 159 | 
            +
                            shard_files = set(weight_map.values())
         | 
| 160 | 
            +
                            state_dict = {}
         | 
| 161 | 
            +
                            for shard_file in shard_files:
         | 
| 162 | 
            +
                                shard_path = f"{model_dir}/{shard_file}"
         | 
| 163 | 
            +
                                shard_state = load_file(shard_path)
         | 
| 164 | 
            +
                                shard_state = {k: v.to(device) for k, v in shard_state.items()}
         | 
| 165 | 
            +
                                state_dict.update(shard_state)
         | 
| 166 | 
            +
                            return state_dict
         | 
| 167 | 
            +
                        state = load_custom_sharded_weights(path, 'humo', device)
         | 
| 168 | 
            +
                        self.dit.load_state_dict(state, strict=False, assign=True)
         | 
| 169 | 
            +
                    
         | 
| 170 | 
            +
                    self.dit = meta_non_persistent_buffer_init_fn(self.dit)
         | 
| 171 | 
            +
                    if device in [get_device(), "cuda"]:
         | 
| 172 | 
            +
                        self.dit.to(get_device())
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                    # Print model size.
         | 
| 175 | 
            +
                    params = sum(p.numel() for p in self.dit.parameters())
         | 
| 176 | 
            +
                    self.logger.info(
         | 
| 177 | 
            +
                        f"[RANK:{get_global_rank()}] DiT Parameters: {clever_format(params, '%.3f')}"
         | 
| 178 | 
            +
                    )
         | 
| 179 | 
            +
                
         | 
| 180 | 
            +
                def configure_vae_model(self, device=get_device()):
         | 
| 181 | 
            +
                    self.vae_stride = self.config.vae.vae_stride
         | 
| 182 | 
            +
                    self.vae = WanVAE(
         | 
| 183 | 
            +
                        vae_pth=self.config.vae.checkpoint,
         | 
| 184 | 
            +
                        device=device)
         | 
| 185 | 
            +
                    
         | 
| 186 | 
            +
                    if self.config.generation.height == 480:
         | 
| 187 | 
            +
                        self.zero_vae = torch.load(self.config.dit.zero_vae_path)
         | 
| 188 | 
            +
                    elif self.config.generation.height == 720:
         | 
| 189 | 
            +
                        self.zero_vae = torch.load(self.config.dit.zero_vae_720p_path)
         | 
| 190 | 
            +
                    else:
         | 
| 191 | 
            +
                        raise ValueError(f"Unsupported height {self.config.generation.height} for zero-vae.")
         | 
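            +
                    # Note: zero_vae appears to be a precomputed VAE latent of an
                    # all-zero clip at the matching resolution; inference uses it to
                    # pad the latent frames that are to be generated.
         | 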
| 192 | 
            +
                
         | 
| 193 | 
            +
                def configure_wav2vec(self, device=get_device()):
         | 
| 194 | 
            +
                    audio_separator_model_file = self.config.audio.vocal_separator
         | 
| 195 | 
            +
                    wav2vec_model_path = self.config.audio.wav2vec_model
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                    self.audio_processor = AudioProcessor(
         | 
| 198 | 
            +
                        16000,
         | 
| 199 | 
            +
                        25,
         | 
| 200 | 
            +
                        wav2vec_model_path,
         | 
| 201 | 
            +
                        "all",
         | 
| 202 | 
            +
                        audio_separator_model_file,
         | 
| 203 | 
            +
                        None,  # do not run vocal separation
         | 
| 204 | 
            +
                        os.path.join(self.config.generation.output.dir, "vocals"),
         | 
| 205 | 
            +
                        device=device,
         | 
| 206 | 
            +
                    )
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                def configure_text_model(self, device=get_device()):
         | 
| 209 | 
            +
                    self.text_encoder = T5EncoderModel(
         | 
| 210 | 
            +
                        text_len=self.config.dit.model.text_len,
         | 
| 211 | 
            +
                        dtype=torch.bfloat16,
         | 
| 212 | 
            +
                        device=device,
         | 
| 213 | 
            +
                        checkpoint_path=self.config.text.t5_checkpoint,
         | 
| 214 | 
            +
                        tokenizer_path=self.config.text.t5_tokenizer,
         | 
| 215 | 
            +
                        )
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                
         | 
| 218 | 
            +
                def configure_dit_fsdp_model(self):
         | 
| 219 | 
            +
                    self.dit.to(get_device())
         | 
| 220 | 
            +
                    
         | 
| 221 | 
            +
                    return   
         | 
| 222 | 
            +
             | 
| 223 | 
            +
             | 
| 224 | 
            +
                def configure_text_fsdp_model(self):
         | 
| 225 | 
            +
                    self.text_encoder.to(get_device())
         | 
| 226 | 
            +
                    
         | 
| 227 | 
            +
                    return
         | 
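            +
                # Note: despite their names, both *_fsdp_model hooks above simply
                # move the modules onto the local device; FSDP sharding is
                # effectively disabled in this variant, presumably for single-GPU
                # inference, even though the FSDP classes are imported.
         | 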
| 228 | 
            +
             | 
| 229 | 
            +
             | 
| 230 | 
            +
                def load_image_latent_ref_id(self, path: str, size, device):
         | 
| 231 | 
            +
                    # Load size.
         | 
| 232 | 
            +
                    h, w = size[1], size[0]
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                    # Load image.
         | 
| 235 | 
            +
                    if not isinstance(path, str) and len(path) > 1:
         | 
| 236 | 
            +
                        ref_vae_latents = []
         | 
| 237 | 
            +
                        for image_path in path:
         | 
| 238 | 
            +
                            with Image.open(image_path) as img:
         | 
| 239 | 
            +
                                img = img.convert("RGB")
         | 
| 240 | 
            +
             | 
| 241 | 
            +
                                # Calculate the required size to keep aspect ratio and fill the rest with padding.
         | 
| 242 | 
            +
                                img_ratio = img.width / img.height
         | 
| 243 | 
            +
                                target_ratio = w / h
         | 
| 244 | 
            +
                                
         | 
| 245 | 
            +
                                if img_ratio > target_ratio:  # Image is wider than target
         | 
| 246 | 
            +
                                    new_width = w
         | 
| 247 | 
            +
                                    new_height = int(new_width / img_ratio)
         | 
| 248 | 
            +
                                else:  # Image is taller than target
         | 
| 249 | 
            +
                                    new_height = h
         | 
| 250 | 
            +
                                    new_width = int(new_height * img_ratio)
         | 
| 251 | 
            +
                                
         | 
| 252 | 
            +
                                # img = img.resize((new_width, new_height), Image.ANTIALIAS)
         | 
| 253 | 
            +
                                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                                # Create a new image with the target size and place the resized image in the center
         | 
| 256 | 
            +
                                delta_w = w - img.size[0]
         | 
| 257 | 
            +
                                delta_h = h - img.size[1]
         | 
| 258 | 
            +
                                padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
         | 
| 259 | 
            +
                                new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                                # Transform to tensor and normalize.
         | 
| 262 | 
            +
                                transform = Compose(
         | 
| 263 | 
            +
                                    [
         | 
| 264 | 
            +
                                        ToTensor(),
         | 
| 265 | 
            +
                                        Normalize(0.5, 0.5),
         | 
| 266 | 
            +
                                    ]
         | 
| 267 | 
            +
                                )
         | 
| 268 | 
            +
                                new_img = transform(new_img)
         | 
| 269 | 
            +
                                # img_vae_latent = self.vae_encode([new_img.unsqueeze(1)])[0]
         | 
| 270 | 
            +
                                img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device)
         | 
| 271 | 
            +
                                ref_vae_latents.append(img_vae_latent[0])
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                        return [torch.cat(ref_vae_latents, dim=1)]
         | 
| 274 | 
            +
                    else:
         | 
| 275 | 
            +
                        if not isinstance(path, str):
         | 
| 276 | 
            +
                            path = path[0]
         | 
| 277 | 
            +
                        with Image.open(path) as img:
         | 
| 278 | 
            +
                            img = img.convert("RGB")
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                            # Calculate the required size to keep aspect ratio and fill the rest with padding.
         | 
| 281 | 
            +
                            img_ratio = img.width / img.height
         | 
| 282 | 
            +
                            target_ratio = w / h
         | 
| 283 | 
            +
                            
         | 
| 284 | 
            +
                            if img_ratio > target_ratio:  # Image is wider than target
         | 
| 285 | 
            +
                                new_width = w
         | 
| 286 | 
            +
                                new_height = int(new_width / img_ratio)
         | 
| 287 | 
            +
                            else:  # Image is taller than target
         | 
| 288 | 
            +
                                new_height = h
         | 
| 289 | 
            +
                                new_width = int(new_height * img_ratio)
         | 
| 290 | 
            +
                            
         | 
| 291 | 
            +
                            # img = img.resize((new_width, new_height), Image.ANTIALIAS)
         | 
| 292 | 
            +
                            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
         | 
| 293 | 
            +
             | 
| 294 | 
            +
                            # Create a new image with the target size and place the resized image in the center
         | 
| 295 | 
            +
                            delta_w = w - img.size[0]
         | 
| 296 | 
            +
                            delta_h = h - img.size[1]
         | 
| 297 | 
            +
                            padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
         | 
| 298 | 
            +
                            new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                            # Transform to tensor and normalize.
         | 
| 301 | 
            +
                            transform = Compose(
         | 
| 302 | 
            +
                                [
         | 
| 303 | 
            +
                                    ToTensor(),
         | 
| 304 | 
            +
                                    Normalize(0.5, 0.5),
         | 
| 305 | 
            +
                                ]
         | 
| 306 | 
            +
                            )
         | 
| 307 | 
            +
                            new_img = transform(new_img)
         | 
| 308 | 
            +
                            img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device)
         | 
| 309 | 
            +
             | 
| 310 | 
            +
                        # Vae encode.
         | 
| 311 | 
            +
                        return img_vae_latent
         | 
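            +
                # Worked example of the letterboxing above: a 1024x1024 reference
                # resized for an 832x480 target keeps its aspect ratio (480x480)
                # and is padded with white to 832x480, i.e. padding = (176, 0, 176, 0).
                # With multiple references, each image is VAE-encoded independently
                # and the latents are concatenated along the temporal axis (dim=1).
         | 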
| 312 | 
            +
                
         | 
| 313 | 
            +
                def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2):
         | 
| 314 | 
            +
                    zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
         | 
| 315 | 
            +
                    zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)  # device=audio_emb.device
         | 
| 316 | 
            +
                    iter_ = 1 + (frame_num - 1) // 4
         | 
| 317 | 
            +
                    audio_emb_wind = []
         | 
| 318 | 
            +
                    for lt_i in range(iter_):
         | 
| 319 | 
            +
                        if lt_i == 0:
         | 
| 320 | 
            +
                            st = frame0_idx + lt_i - 2
         | 
| 321 | 
            +
                            ed = frame0_idx + lt_i + 3
         | 
| 322 | 
            +
                            wind_feat = torch.stack([
         | 
| 323 | 
            +
                                audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
         | 
| 324 | 
            +
                                for i in range(st, ed)
         | 
| 325 | 
            +
                            ], dim=0)
         | 
| 326 | 
            +
                            wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0)
         | 
| 327 | 
            +
                        else:
         | 
| 328 | 
            +
                            st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
         | 
| 329 | 
            +
                            ed = frame0_idx + 1 + 4 * lt_i + audio_shift
         | 
| 330 | 
            +
                            wind_feat = torch.stack([
         | 
| 331 | 
            +
                                audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
         | 
| 332 | 
            +
                                for i in range(st, ed)
         | 
| 333 | 
            +
                            ], dim=0)
         | 
| 334 | 
            +
                        audio_emb_wind.append(wind_feat)
         | 
| 335 | 
            +
                    audio_emb_wind = torch.stack(audio_emb_wind, dim=0)
         | 
| 336 | 
            +
             | 
| 337 | 
            +
                    return audio_emb_wind, ed - audio_shift
         | 
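            +
                # Note: latent frame 0 covers 1 video frame and every later latent
                # frame covers 4, hence iter_ = 1 + (frame_num - 1) // 4. The first
                # window is 3 zero embeddings plus frames [frame0_idx-2, frame0_idx+3);
                # later windows span 4 frames plus audio_shift context on each side.
                # With audio_shift=2 every window is 8 embeddings long, and indices
                # outside the clip fall back to zeros.
         | 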
| 338 | 
            +
                
         | 
| 339 | 
            +
                def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"):
         | 
| 340 | 
            +
                    if wav_enc_type == "wav2vec":
         | 
| 341 | 
            +
                        feat_merge = audio_emb
         | 
| 342 | 
            +
                    elif wav_enc_type == "whisper":
         | 
| 343 | 
            +
                        feat0 = linear_interpolation_fps(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25)
         | 
| 344 | 
            +
                        feat1 = linear_interpolation_fps(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25)
         | 
| 345 | 
            +
                        feat2 = linear_interpolation_fps(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25)
         | 
| 346 | 
            +
                        feat3 = linear_interpolation_fps(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25)
         | 
| 347 | 
            +
                        feat4 = linear_interpolation_fps(audio_emb[:, :, 32], 50, 25)
         | 
| 348 | 
            +
                        feat_merge = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0]
         | 
| 349 | 
            +
                    else:
         | 
| 350 | 
            +
                        raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}")
         | 
| 351 | 
            +
                    
         | 
| 352 | 
            +
                    return feat_merge
         | 
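            +
                # Note: for whisper features, axis 2 appears to index 33 encoder
                # hidden states (indices 0..32); they are averaged into four groups
                # of eight plus the final state, and each group is resampled from
                # 50 Hz to the 25 fps video rate, giving a (frames, 5, dim) tensor.
         | 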
| 353 | 
            +
                
         | 
| 354 | 
            +
                def forward_tia(self, latents, latents_ref, latents_ref_neg, timestep, arg_t, arg_ta, arg_null):
         | 
| 355 | 
            +
                    neg = self.dit(
         | 
| 356 | 
            +
                        [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_null
         | 
| 357 | 
            +
                        )[0]
         | 
| 358 | 
            +
                    
         | 
| 359 | 
            +
                    pos_t = self.dit(
         | 
| 360 | 
            +
                        [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_t
         | 
| 361 | 
            +
                        )[0]
         | 
| 362 | 
            +
                    pos_ta = self.dit(
         | 
| 363 | 
            +
                        [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_ta
         | 
| 364 | 
            +
                        )[0]
         | 
| 365 | 
            +
                    pos_tia = self.dit(
         | 
| 366 | 
            +
                        [torch.cat([latent[:,:-latent_ref.shape[1]], latent_ref], dim=1) for latent, latent_ref in zip(latents, latents_ref)], t=timestep, **arg_ta
         | 
| 367 | 
            +
                        )[0]
         | 
| 368 | 
            +
                    
         | 
| 369 | 
            +
                    noise_pred = self.config.generation.scale_i * (pos_tia - pos_ta) + \
         | 
| 370 | 
            +
                                self.config.generation.scale_a * (pos_ta - pos_t) + \
         | 
| 371 | 
            +
                                self.config.generation.scale_t * (pos_t - neg) + \
         | 
| 372 | 
            +
                                neg
         | 
| 373 | 
            +
                    
         | 
| 374 | 
            +
                    return noise_pred
         | 
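            +
                # Note: this composes nested classifier-free guidance. Writing the
                # four predictions as eps_null, eps_t, eps_ta, eps_tia, the result is
                #   eps = eps_null + scale_t * (eps_t   - eps_null)
                #                  + scale_a * (eps_ta  - eps_t)
                #                  + scale_i * (eps_tia - eps_ta)
                # so each scale independently strengthens the text, audio, and
                # reference-image conditions. forward_ta below applies the same
                # scheme without the image term.
         | 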
| 375 | 
            +
                
         | 
| 376 | 
            +
                def forward_ta(self, latents, latents_ref_neg, timestep, arg_t, arg_ta, arg_null):
         | 
| 377 | 
            +
                    neg = self.dit(
         | 
| 378 | 
            +
                        [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_null
         | 
| 379 | 
            +
                        )[0]
         | 
| 380 | 
            +
                    
         | 
| 381 | 
            +
                    pos_t = self.dit(
         | 
| 382 | 
            +
                        [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_t
         | 
| 383 | 
            +
                        )[0]
         | 
| 384 | 
            +
                    pos_ta = self.dit(
         | 
| 385 | 
            +
                        [torch.cat([latent[:,:-latent_ref_neg.shape[1]], latent_ref_neg], dim=1) for latent, latent_ref_neg in zip(latents, latents_ref_neg)], t=timestep, **arg_ta
         | 
| 386 | 
            +
                        )[0]
         | 
| 387 | 
            +
                    
         | 
| 388 | 
            +
                    noise_pred = self.config.generation.scale_a * (pos_ta - pos_t) + \
         | 
| 389 | 
            +
                                self.config.generation.scale_t * (pos_t - neg) + \
         | 
| 390 | 
            +
                                neg
         | 
| 391 | 
            +
                    
         | 
| 392 | 
            +
                    return noise_pred
         | 
| 393 | 
            +
                        
         | 
| 394 | 
            +
                                
         | 
| 395 | 
            +
                @torch.no_grad()
         | 
| 396 | 
            +
                def inference(self,
         | 
| 397 | 
            +
                             input_prompt,
         | 
| 398 | 
            +
                             img_path,
         | 
| 399 | 
            +
                             audio_path,
         | 
| 400 | 
            +
                             size=(1280, 720),
         | 
| 401 | 
            +
                             frame_num=81,
         | 
| 402 | 
            +
                             shift=5.0,
         | 
| 403 | 
            +
                             sample_solver='unipc',
         | 
| 404 | 
            +
                             sampling_steps=50,
         | 
| 405 | 
            +
                             n_prompt="",
         | 
| 406 | 
            +
                             seed=-1,
         | 
| 407 | 
            +
                             offload_model=True,
         | 
| 408 | 
            +
                         device=get_device(),
         | 
| 409 | 
            +
                    ):
         | 
| 410 | 
            +
             | 
| 411 | 
            +
                    self.vae.model.to(device=device)
         | 
| 412 | 
            +
                    if img_path is not None:
         | 
| 413 | 
            +
                        latents_ref = self.load_image_latent_ref_id(img_path, size, device)
         | 
| 414 | 
            +
                    else:
         | 
| 415 | 
            +
                        latents_ref = [torch.zeros(16, 1, size[1]//8, size[0]//8).to(device)]
         | 
| 416 | 
            +
                        
         | 
| 417 | 
            +
                    self.vae.model.to(device="cpu")
         | 
| 418 | 
            +
                    latents_ref_neg = [torch.zeros_like(latent_ref) for latent_ref in latents_ref]
         | 
| 419 | 
            +
                    
         | 
| 420 | 
            +
                    # audio
         | 
| 421 | 
            +
                    if audio_path is not None:
         | 
| 422 | 
            +
                        if self.config.generation.extract_audio_feat:
         | 
| 423 | 
            +
                            self.audio_processor.whisper.to(device=device)
         | 
| 424 | 
            +
                            audio_emb, audio_length = self.audio_processor.preprocess(audio_path)
         | 
| 425 | 
            +
                            self.audio_processor.whisper.to(device='cpu')
         | 
| 426 | 
            +
                        else:
         | 
| 427 | 
            +
                            audio_emb_path = audio_path.replace(".wav", ".pt")
         | 
| 428 | 
            +
                            audio_emb = torch.load(audio_emb_path).to(device=device)
         | 
| 429 | 
            +
                            audio_emb = self.audio_emb_enc(audio_emb, wav_enc_type="whisper")
         | 
| 430 | 
            +
                            self.logger.info("使用预先提取好的音频特征: %s", audio_emb_path)
         | 
| 431 | 
            +
                    else:
         | 
| 432 | 
            +
                        audio_emb = torch.zeros(frame_num, 5, 1280).to(device)
         | 
| 433 | 
            +
                        
         | 
| 434 | 
            +
                frame_num = frame_num if frame_num != -1 else audio_length  # audio_length is only defined when audio_path is given
         | 
| 435 | 
            +
                    frame_num = 4 * ((frame_num - 1) // 4) + 1
         | 
| 436 | 
            +
                    audio_emb, _ = self.get_audio_emb_window(audio_emb, frame_num, frame0_idx=0)
         | 
| 437 | 
            +
                    zero_audio_pad = torch.zeros(latents_ref[0].shape[1], *audio_emb.shape[1:]).to(audio_emb.device)
         | 
| 438 | 
            +
                    audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0)
         | 
| 439 | 
            +
                    audio_emb = [audio_emb.to(device)]
         | 
| 440 | 
            +
                    audio_emb_neg = [torch.zeros_like(audio_emb[0])]
         | 
| 441 | 
            +
                    
         | 
| 442 | 
            +
                    # preprocess
         | 
| 443 | 
            +
                    self.patch_size = self.config.dit.model.patch_size
         | 
| 444 | 
            +
                    F = frame_num
         | 
| 445 | 
            +
                    target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + latents_ref[0].shape[1],
         | 
| 446 | 
            +
                                    size[1] // self.vae_stride[1],
         | 
| 447 | 
            +
                                    size[0] // self.vae_stride[2])
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                    seq_len = math.ceil((target_shape[2] * target_shape[3]) /
         | 
| 450 | 
            +
                                        (self.patch_size[1] * self.patch_size[2]) *
         | 
| 451 | 
            +
                                        target_shape[1] / self.sp_size) * self.sp_size
         | 
| 452 | 
            +
             | 
| 453 | 
            +
                    if n_prompt == "":
         | 
| 454 | 
            +
                        n_prompt = self.config.generation.sample_neg_prompt
         | 
| 455 | 
            +
                    seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
         | 
| 456 | 
            +
                    seed_g = torch.Generator(device=device)
         | 
| 457 | 
            +
                    seed_g.manual_seed(seed)
         | 
| 458 | 
            +
             | 
| 459 | 
            +
                    self.text_encoder.model.to(device)
         | 
| 460 | 
            +
                    context = self.text_encoder([input_prompt], device)
         | 
| 461 | 
            +
                    context_null = self.text_encoder([n_prompt], device)
         | 
| 462 | 
            +
                    self.text_encoder.model.cpu()
         | 
| 463 | 
            +
             | 
| 464 | 
            +
                    noise = [
         | 
| 465 | 
            +
                        torch.randn(
         | 
| 466 | 
            +
                            target_shape[0],
         | 
| 467 | 
            +
                            target_shape[1], # - latents_ref[0].shape[1],
         | 
| 468 | 
            +
                            target_shape[2],
         | 
| 469 | 
            +
                            target_shape[3],
         | 
| 470 | 
            +
                            dtype=torch.float32,
         | 
| 471 | 
            +
                            device=device,
         | 
| 472 | 
            +
                            generator=seed_g)
         | 
| 473 | 
            +
                    ]
         | 
| 474 | 
            +
             | 
| 475 | 
            +
                    @contextmanager
         | 
| 476 | 
            +
                    def noop_no_sync():
         | 
| 477 | 
            +
                        yield
         | 
| 478 | 
            +
             | 
| 479 | 
            +
                    no_sync = getattr(self.dit, 'no_sync', noop_no_sync)
         | 
| 480 | 
            +
                    # step_change = self.config.generation.step_change # 980
         | 
| 481 | 
            +
             | 
| 482 | 
            +
                    # evaluation mode
         | 
| 483 | 
            +
                    with amp.autocast(dtype=torch.bfloat16), torch.no_grad(), no_sync():
         | 
| 484 | 
            +
             | 
| 485 | 
            +
                        if sample_solver == 'unipc':
         | 
| 486 | 
            +
                            sample_scheduler = FlowUniPCMultistepScheduler(
         | 
| 487 | 
            +
                                num_train_timesteps=1000,
         | 
| 488 | 
            +
                                shift=1,
         | 
| 489 | 
            +
                                use_dynamic_shifting=False)
         | 
| 490 | 
            +
                            sample_scheduler.set_timesteps(
         | 
| 491 | 
            +
                                sampling_steps, device=device, shift=shift)
         | 
| 492 | 
            +
                            timesteps = sample_scheduler.timesteps
         | 
| 493 | 
            +
             | 
| 494 | 
            +
                        # sample videos
         | 
| 495 | 
            +
                        latents = noise
         | 
| 496 | 
            +
             | 
| 497 | 
            +
                    # The reference image is injected manually in the DiT calls below, not passed via these arg dicts
         | 
| 498 | 
            +
                        arg_ta = {'context': context, 'seq_len': seq_len, 'audio': audio_emb}
         | 
| 499 | 
            +
                        arg_t = {'context': context, 'seq_len': seq_len, 'audio': audio_emb_neg}
         | 
| 500 | 
            +
                        arg_null = {'context': context_null, 'seq_len': seq_len, 'audio': audio_emb_neg}
         | 
| 501 | 
            +
                        
         | 
| 502 | 
            +
                        torch.cuda.empty_cache()
         | 
| 503 | 
            +
                        self.dit.to(device=get_device())
         | 
| 504 | 
            +
                        for _, t in enumerate(tqdm(timesteps)):
         | 
| 505 | 
            +
                            timestep = [t]
         | 
| 506 | 
            +
                            timestep = torch.stack(timestep)
         | 
| 507 | 
            +
             | 
| 508 | 
            +
                            if self.config.generation.mode == "TIA":
         | 
| 509 | 
            +
                                noise_pred = self.forward_tia(latents, latents_ref, latents_ref_neg, timestep, arg_t, arg_ta, arg_null)
         | 
| 510 | 
            +
                            elif self.config.generation.mode == "TA":
         | 
| 511 | 
            +
                                noise_pred = self.forward_ta(latents, latents_ref_neg, timestep, arg_t, arg_ta, arg_null)
         | 
| 512 | 
            +
                            else:
         | 
| 513 | 
            +
                                raise ValueError(f"Unsupported generation mode: {self.config.generation.mode}")
         | 
| 514 | 
            +
             | 
| 515 | 
            +
                            temp_x0 = sample_scheduler.step(
         | 
| 516 | 
            +
                                noise_pred.unsqueeze(0),
         | 
| 517 | 
            +
                                t,
         | 
| 518 | 
            +
                                latents[0].unsqueeze(0),
         | 
| 519 | 
            +
                                return_dict=False,
         | 
| 520 | 
            +
                                generator=seed_g)[0]
         | 
| 521 | 
            +
                            latents = [temp_x0.squeeze(0)]
         | 
| 522 | 
            +
             | 
| 523 | 
            +
                            del timestep
         | 
| 524 | 
            +
                            torch.cuda.empty_cache()
         | 
| 525 | 
            +
             | 
| 526 | 
            +
                        x0 = latents
         | 
| 527 | 
            +
                        x0 = [x0_[:,:-latents_ref[0].shape[1]] for x0_ in x0]
         | 
| 528 | 
            +
             | 
| 529 | 
            +
                        # if offload_model:
         | 
| 530 | 
            +
                        self.dit.cpu()
         | 
| 531 | 
            +
                        torch.cuda.empty_cache()
         | 
| 532 | 
            +
                        # if get_local_rank() == 0:
         | 
| 533 | 
            +
                        self.vae.model.to(device=device)
         | 
| 534 | 
            +
                        videos = self.vae.decode(x0)
         | 
| 535 | 
            +
                        self.vae.model.to(device="cpu")
         | 
| 536 | 
            +
             | 
| 537 | 
            +
                    del noise, latents, noise_pred
         | 
| 538 | 
            +
                    del audio_emb, audio_emb_neg, latents_ref, latents_ref_neg, context, context_null
         | 
| 539 | 
            +
                    del x0, temp_x0
         | 
| 540 | 
            +
                    del sample_scheduler
         | 
| 541 | 
            +
                    torch.cuda.empty_cache()
         | 
| 542 | 
            +
                    gc.collect()
         | 
| 543 | 
            +
                    torch.cuda.synchronize()
         | 
| 544 | 
            +
                    if dist.is_initialized():
         | 
| 545 | 
            +
                        dist.barrier()
         | 
| 546 | 
            +
             | 
| 547 | 
            +
                    return videos[0] # if get_local_rank() == 0 else None
         | 
| 548 | 
            +
             | 
| 549 | 
            +
             | 
| 550 | 
            +
    def inference_loop(self, prompt, ref_img_path, audio_path, output_dir, filename, width=832, height=480, steps=50, frames=97, seed=0):
         | 
| 551 | 
            +
        print(f'ref_img_path: {ref_img_path}')
         | 
| 552 | 
            +
             | 
| 553 | 
            +
                    video = self.inference(
         | 
| 554 | 
            +
                        prompt,
         | 
| 555 | 
            +
                        ref_img_path,
         | 
| 556 | 
            +
                        audio_path,
         | 
| 557 | 
            +
                        size=SIZE_CONFIGS[f"{width}*{height}"],
         | 
| 558 | 
            +
                        frame_num=frames,
         | 
| 559 | 
            +
                        shift=self.config.diffusion.timesteps.sampling.shift,
         | 
| 560 | 
            +
                        sample_solver='unipc',
         | 
| 561 | 
            +
                        sampling_steps=steps,
         | 
| 562 | 
            +
                        seed=seed,
         | 
| 563 | 
            +
                        offload_model=False,
         | 
| 564 | 
            +
                    )
         | 
| 565 | 
            +
             | 
| 566 | 
            +
                    torch.cuda.empty_cache()
         | 
| 567 | 
            +
                    gc.collect()
         | 
| 568 | 
            +
                    
         | 
| 569 | 
            +
             | 
| 570 | 
            +
                    # Save samples.
         | 
| 571 | 
            +
                    if get_sequence_parallel_rank() == 0:
         | 
| 572 | 
            +
                        pathname = self.save_sample(
         | 
| 573 | 
            +
                            sample=video,
         | 
| 574 | 
            +
                            audio_path=audio_path,
         | 
| 575 | 
            +
                output_dir=output_dir,
         | 
| 576 | 
            +
                            filename=filename,
         | 
| 577 | 
            +
                        )
         | 
| 578 | 
            +
                        self.logger.info(f"Finished {filename}, saved to {pathname}.")
         | 
| 579 | 
            +
                    
         | 
| 580 | 
            +
                    del video, prompt
         | 
| 581 | 
            +
                    torch.cuda.empty_cache()
         | 
| 582 | 
            +
                    gc.collect()
         | 
| 583 | 
            +
             | 
| 584 | 
            +
                        
         | 
| 585 | 
            +
             | 
| 586 | 
            +
                def save_sample(self, *, sample: torch.Tensor, audio_path: str, output_dir: str, filename: str):
         | 
| 587 | 
            +
                    gen_config = self.config.generation
         | 
| 588 | 
            +
                    # Prepare file path.
         | 
| 589 | 
            +
                    extension = ".mp4" if sample.ndim == 4 else ".png"
         | 
| 590 | 
            +
                    filename += extension
         | 
| 591 | 
            +
                    pathname = os.path.join(output_dir, filename)
         | 
| 592 | 
            +
                    # Convert sample.
         | 
| 593 | 
            +
                    sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).to("cpu", torch.uint8)
         | 
| 594 | 
            +
                    sample = rearrange(sample, "c t h w -> t h w c")
         | 
| 595 | 
            +
                    # Save file.
         | 
| 596 | 
            +
                    if sample.ndim == 4:
         | 
| 597 | 
            +
                        if audio_path is not None:
         | 
| 598 | 
            +
                            tensor_to_video(
         | 
| 599 | 
            +
                                sample.numpy(),
         | 
| 600 | 
            +
                                pathname,
         | 
| 601 | 
            +
                                audio_path,
         | 
| 602 | 
            +
                                fps=gen_config.fps)
         | 
| 603 | 
            +
                        else:
         | 
| 604 | 
            +
                            mediapy.write_video(
         | 
| 605 | 
            +
                                path=pathname,
         | 
| 606 | 
            +
                                images=sample.numpy(),
         | 
| 607 | 
            +
                                fps=gen_config.fps,
         | 
| 608 | 
            +
                            )
         | 
| 609 | 
            +
                    else:
         | 
| 610 | 
            +
                    raise ValueError(f"expected a 4D video tensor, got ndim={sample.ndim}")
         | 
| 611 | 
            +
                    return pathname
         | 
| 612 | 
            +
                
         | 
| 613 | 
            +
             | 
| 614 | 
            +
                def prepare_positive_prompts(self):
         | 
| 615 | 
            +
                    pos_prompts = self.config.generation.positive_prompt
         | 
| 616 | 
            +
                    if pos_prompts.endswith(".json"):
         | 
| 617 | 
            +
                        pos_prompts = prepare_json_dataset(pos_prompts)
         | 
| 618 | 
            +
                    else:
         | 
| 619 | 
            +
                        raise NotImplementedError
         | 
| 620 | 
            +
                    assert isinstance(pos_prompts, ListConfig)
         | 
| 621 | 
            +
             | 
| 622 | 
            +
                    return pos_prompts
         | 
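    The guidance composition in forward_tia and forward_ta above is a nested classifier-free guidance: three deltas (identity, audio, text) are scaled and added back onto the unconditional prediction. A minimal standalone sketch of that arithmetic, with illustrative scale values standing in for config.generation.scale_i / scale_a / scale_t (not the repo's defaults):

    import torch

    # Sketch of the nested CFG combination from forward_tia/forward_ta.
    # neg, pos_t, pos_ta, pos_tia are the unconditional, text-only,
    # text+audio, and text+audio+identity predictions.
    def combine_guidance(neg, pos_t, pos_ta, pos_tia,
                         scale_t=7.5, scale_a=2.0, scale_i=2.0):
        return (scale_i * (pos_tia - pos_ta)
                + scale_a * (pos_ta - pos_t)
                + scale_t * (pos_t - neg)
                + neg)

    # Sanity check: with every scale at 1 the sum telescopes to pos_tia,
    # i.e. the fully conditioned prediction.
    preds = [torch.randn(4) for _ in range(4)]
    assert torch.allclose(combine_guidance(*preds, scale_t=1, scale_a=1, scale_i=1), preds[3])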
    	
        humo/models/audio/audio_proj.py
    ADDED
    
    | @@ -0,0 +1,87 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            from einops import rearrange
         | 
| 3 | 
            +
            from torch import nn
         | 
| 4 | 
            +
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            class WanRMSNorm(nn.Module):
         | 
| 7 | 
            +
             | 
| 8 | 
            +
                def __init__(self, dim, eps=1e-5):
         | 
| 9 | 
            +
                    super().__init__()
         | 
| 10 | 
            +
                    self.dim = dim
         | 
| 11 | 
            +
                    self.eps = eps
         | 
| 12 | 
            +
                    self.weight = nn.Parameter(torch.ones(dim))
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                def forward(self, x):
         | 
| 15 | 
            +
                    r"""
         | 
| 16 | 
            +
                    Args:
         | 
| 17 | 
            +
                        x(Tensor): Shape [B, L, C]
         | 
| 18 | 
            +
                    """
         | 
| 19 | 
            +
                    return self._norm(x.float()).type_as(x) * self.weight
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                def _norm(self, x):
         | 
| 22 | 
            +
                    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            class DummyAdapterLayer(nn.Module):
         | 
| 26 | 
            +
                def __init__(self, layer):
         | 
| 27 | 
            +
                    super().__init__()
         | 
| 28 | 
            +
                    self.layer = layer
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                def forward(self, *args, **kwargs):
         | 
| 31 | 
            +
                    return self.layer(*args, **kwargs)
         | 
| 32 | 
            +
                
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            class AudioProjModel(nn.Module):
         | 
| 35 | 
            +
                def __init__(
         | 
| 36 | 
            +
                    self,
         | 
| 37 | 
            +
                    seq_len=5,
         | 
| 38 | 
            +
        blocks=13,  # number of feature blocks per audio window
         | 
| 39 | 
            +
        channels=768,  # feature channels per block
         | 
| 40 | 
            +
                    intermediate_dim=512,
         | 
| 41 | 
            +
                    output_dim=1536,
         | 
| 42 | 
            +
                    context_tokens=16,
         | 
| 43 | 
            +
                ):
         | 
| 44 | 
            +
                    super().__init__()
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                    self.seq_len = seq_len
         | 
| 47 | 
            +
                    self.blocks = blocks
         | 
| 48 | 
            +
                    self.channels = channels
         | 
| 49 | 
            +
        self.input_dim = seq_len * blocks * channels  # input_dim is the product of seq_len, blocks and channels
         | 
| 50 | 
            +
                    self.intermediate_dim = intermediate_dim
         | 
| 51 | 
            +
                    self.context_tokens = context_tokens
         | 
| 52 | 
            +
                    self.output_dim = output_dim
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                    # define multiple linear layers
         | 
| 55 | 
            +
                    self.audio_proj_glob_1 = DummyAdapterLayer(nn.Linear(self.input_dim, intermediate_dim))
         | 
| 56 | 
            +
                    self.audio_proj_glob_2 = DummyAdapterLayer(nn.Linear(intermediate_dim, intermediate_dim))
         | 
| 57 | 
            +
                    self.audio_proj_glob_3 = DummyAdapterLayer(nn.Linear(intermediate_dim, context_tokens * output_dim))
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                    self.audio_proj_glob_norm = DummyAdapterLayer(nn.LayerNorm(output_dim))
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                    self.initialize_weights()
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                def initialize_weights(self):
         | 
| 64 | 
            +
        # Initialize all linear layers:
         | 
| 65 | 
            +
                    def _basic_init(module):
         | 
| 66 | 
            +
                        if isinstance(module, nn.Linear):
         | 
| 67 | 
            +
                            torch.nn.init.xavier_uniform_(module.weight)
         | 
| 68 | 
            +
                            if module.bias is not None:
         | 
| 69 | 
            +
                                nn.init.constant_(module.bias, 0)
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    self.apply(_basic_init)
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                def forward(self, audio_embeds):
         | 
| 74 | 
            +
                    video_length = audio_embeds.shape[1]
         | 
| 75 | 
            +
                    audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
         | 
| 76 | 
            +
                    batch_size, window_size, blocks, channels = audio_embeds.shape
         | 
| 77 | 
            +
                    audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                    audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds))
         | 
| 80 | 
            +
                    audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds))
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim)
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                    context_tokens = self.audio_proj_glob_norm(context_tokens)
         | 
| 85 | 
            +
                    context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    return context_tokens
         | 
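    A quick shape walk-through for AudioProjModel: each frame's window x blocks x channels feature slab is flattened, projected through two ReLU linears, and reshaped into context_tokens audio tokens per frame. The sizes below are illustrative, matching the constructor defaults:

    import torch

    # Illustrative shape check (sizes assumed from the defaults above).
    model = AudioProjModel(seq_len=5, blocks=13, channels=768,
                           intermediate_dim=512, output_dim=1536, context_tokens=16)
    audio = torch.randn(1, 8, 5, 13, 768)  # [batch, frames, window, blocks, channels]
    tokens = model(audio)
    print(tokens.shape)  # torch.Size([1, 8, 16, 1536]): 16 context tokens per frame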
    	
        humo/models/distributed/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        humo/models/distributed/dit_ulysses_sequence_parallel.py
    ADDED
    
    | @@ -0,0 +1,270 @@ | |
| 1 | 
            +
            # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
         | 
| 2 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 3 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 4 | 
            +
            # You may obtain a copy of the License at
         | 
| 5 | 
            +
            # http://www.apache.org/licenses/LICENSE-2.0
         | 
| 6 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 7 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 8 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 9 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 10 | 
            +
            # limitations under the License.
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import torch
         | 
| 13 | 
            +
            from einops import rearrange
         | 
| 14 | 
            +
            from common.distributed import get_device
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            from common.distributed.advanced import (
         | 
| 17 | 
            +
                get_unified_parallel_world_size,
         | 
| 18 | 
            +
                get_unified_parallel_group,
         | 
| 19 | 
            +
                pad_tensor,
         | 
| 20 | 
            +
                Slice,
         | 
| 21 | 
            +
                gather_outputs,
         | 
| 22 | 
            +
                gather_seq_scatter_heads_qkv,
         | 
| 23 | 
            +
                gather_seq_scatter_double_head,
         | 
| 24 | 
            +
                gather_heads_scatter_seq,
         | 
| 25 | 
            +
                unpad_tensor
         | 
| 26 | 
            +
            )
         | 
| 27 | 
            +
            from humo.models.wan_modules.attention import flash_attention
         | 
| 28 | 
            +
            from humo.models.wan_modules.model_humo import rope_apply, sinusoidal_embedding_1d
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            def ulysses_dit_forward(
         | 
| 32 | 
            +
                self,
         | 
| 33 | 
            +
                x,
         | 
| 34 | 
            +
                t,
         | 
| 35 | 
            +
                context,
         | 
| 36 | 
            +
                seq_len,
         | 
| 37 | 
            +
                audio=None,
         | 
| 38 | 
            +
                y=None
         | 
| 39 | 
            +
            ):
         | 
| 40 | 
            +
                """
         | 
| 41 | 
            +
                x:              A list of videos each with shape [C, T, H, W].
         | 
| 42 | 
            +
                t:              [B].
         | 
| 43 | 
            +
                context:        A list of text embeddings each with shape [L, C].
         | 
| 44 | 
            +
                """
         | 
| 45 | 
            +
                if self.model_type == 'i2v':
         | 
| 46 | 
            +
                    # assert clip_fea is not None and y is not None
         | 
| 47 | 
            +
                    assert y is not None
         | 
| 48 | 
            +
                # params
         | 
| 49 | 
            +
                device = self.patch_embedding.weight.device
         | 
| 50 | 
            +
                if self.freqs.device != device:
         | 
| 51 | 
            +
                    self.freqs = self.freqs.to(device)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                if y is not None:
         | 
| 54 | 
            +
                    x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                # embeddings
         | 
| 57 | 
            +
                x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
         | 
| 58 | 
            +
                grid_sizes = torch.stack(
         | 
| 59 | 
            +
                    [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
         | 
| 60 | 
            +
                x = [u.flatten(2).transpose(1, 2) for u in x]
         | 
| 61 | 
            +
                seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long, device=device)
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                assert seq_lens.max() <= seq_len
         | 
| 64 | 
            +
                x = torch.cat([
         | 
| 65 | 
            +
                    torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
         | 
| 66 | 
            +
                    for u in x
         | 
| 67 | 
            +
                ])
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                # time embeddings
         | 
| 70 | 
            +
                with torch.amp.autocast('cuda', dtype=torch.float32):
         | 
| 71 | 
            +
                    e = self.time_embedding(
         | 
| 72 | 
            +
                        sinusoidal_embedding_1d(self.freq_dim, t).float()).float()
         | 
| 73 | 
            +
                    e0 = self.time_projection(e).unflatten(1, (6, self.dim)).float()
         | 
| 74 | 
            +
                    assert e.dtype == torch.float32 and e0.dtype == torch.float32
         | 
| 75 | 
            +
                    
         | 
| 76 | 
            +
                # context
         | 
| 77 | 
            +
                context_lens = None
         | 
| 78 | 
            +
                context = self.text_embedding(
         | 
| 79 | 
            +
                    torch.stack([
         | 
| 80 | 
            +
                        torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
         | 
| 81 | 
            +
                        for u in context
         | 
| 82 | 
            +
                    ]))
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                if self.insert_audio:
         | 
| 85 | 
            +
                    audio = [self.audio_proj(au.unsqueeze(0)).permute(0, 3, 1, 2) for au in audio]
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    audio_seq_len = torch.tensor(max([au.shape[2] for au in audio]) * audio[0].shape[3], device=get_device())
         | 
| 88 | 
            +
                    audio = [au.flatten(2).transpose(1, 2) for au in audio] # [1, t*32, 1536]
         | 
| 89 | 
            +
                    audio_seq_lens = torch.tensor([au.size(1) for au in audio], dtype=torch.long, device=device)
         | 
| 90 | 
            +
                    audio = torch.cat([
         | 
| 91 | 
            +
                        torch.cat([au, au.new_zeros(1, audio_seq_len - au.size(1), au.size(2))],
         | 
| 92 | 
            +
                                    dim=1) for au in audio
         | 
| 93 | 
            +
                    ])
         | 
| 94 | 
            +
                else:
         | 
| 95 | 
            +
                    audio = None
         | 
| 96 | 
            +
                    audio_seq_len = None
         | 
| 97 | 
            +
                    audio_seq_lens = None
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                # ulysses support
         | 
| 100 | 
            +
                sp_world = get_unified_parallel_world_size()
         | 
| 101 | 
            +
                group = get_unified_parallel_group()
         | 
| 102 | 
            +
                if seq_len % sp_world:
         | 
| 103 | 
            +
                    padding_size = sp_world - (seq_len % sp_world)
         | 
| 104 | 
            +
                    x = pad_tensor(x, dim=1, padding_size=padding_size)
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                    if self.insert_audio:
         | 
| 107 | 
            +
                        audio_padding_size = sp_world - (audio_seq_len % sp_world)
         | 
| 108 | 
            +
                        audio = pad_tensor(audio, dim=1, padding_size=audio_padding_size)
         | 
| 109 | 
            +
                    
         | 
| 110 | 
            +
                x = Slice.apply(group, x, 1, True)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                if self.insert_audio:
         | 
| 113 | 
            +
                    audio = Slice.apply(group, audio, 1, True)
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                # arguments
         | 
| 116 | 
            +
                kwargs = dict(
         | 
| 117 | 
            +
                    e=e0,
         | 
| 118 | 
            +
                    seq_lens=seq_lens,
         | 
| 119 | 
            +
                    grid_sizes=grid_sizes,
         | 
| 120 | 
            +
                    freqs=self.freqs,
         | 
| 121 | 
            +
                    context=context,
         | 
| 122 | 
            +
                    context_lens=context_lens,
         | 
| 123 | 
            +
                    audio=audio,
         | 
| 124 | 
            +
                    audio_seq_len=audio_seq_len)
         | 
| 125 | 
            +
                
         | 
| 126 | 
            +
                for block in self.blocks:
         | 
| 127 | 
            +
                    x = block(x, **kwargs)
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                # head
         | 
| 130 | 
            +
                x = self.head(x, e)
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                # ulysses support
         | 
| 133 | 
            +
                x = gather_outputs(x, gather_dim=1, padding_dim=1, unpad_dim_size=seq_len, scale_grad=True)
         | 
| 134 | 
            +
                
         | 
| 135 | 
            +
                # unpatchify
         | 
| 136 | 
            +
                x = self.unpatchify(x, grid_sizes)
         | 
| 137 | 
            +
                return [u.float() for u in x]
         | 
| 138 | 
            +
             | 
| 139 | 
            +
             | 
| 140 | 
            +
            def ulysses_attn_forward(
         | 
| 141 | 
            +
                self,
         | 
| 142 | 
            +
                x,
         | 
| 143 | 
            +
                seq_lens,
         | 
| 144 | 
            +
                grid_sizes,
         | 
| 145 | 
            +
                freqs,
         | 
| 146 | 
            +
                dtype=torch.bfloat16
         | 
| 147 | 
            +
            ):
         | 
| 148 | 
            +
                b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
         | 
| 149 | 
            +
                seq_len = seq_lens.max()
         | 
| 150 | 
            +
                half_dtypes = (torch.float16, torch.bfloat16)
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                def half(x):
         | 
| 153 | 
            +
                    return x if x.dtype in half_dtypes else x.to(dtype)
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                # query, key, value function
         | 
| 156 | 
            +
                def qkv_fn(x):
         | 
| 157 | 
            +
                    q = self.norm_q(self.q(x))
         | 
| 158 | 
            +
                    k = self.norm_k(self.k(x))
         | 
| 159 | 
            +
                    v = self.v(x)
         | 
| 160 | 
            +
                    return q, k, v
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                q, k, v = qkv_fn(x)
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                # ulysses support
         | 
| 165 | 
            +
                sp_size = get_unified_parallel_world_size()
         | 
| 166 | 
            +
                if n % sp_size:
         | 
| 167 | 
            +
                    pad_size = sp_size - (n % sp_size)
         | 
| 168 | 
            +
                    pad_size = pad_size * d
         | 
| 169 | 
            +
                    pad_inner_dim = n * d + pad_size
         | 
| 170 | 
            +
                    q = pad_tensor(q, dim=2, padding_size=pad_size)
         | 
| 171 | 
            +
                    k = pad_tensor(k, dim=2, padding_size=pad_size)
         | 
| 172 | 
            +
                    v = pad_tensor(v, dim=2, padding_size=pad_size)
         | 
| 173 | 
            +
                else:
         | 
| 174 | 
            +
                    pad_inner_dim = n * d
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                qkv = torch.cat([q, k, v], dim=2)
         | 
| 177 | 
            +
                qkv = gather_seq_scatter_heads_qkv(qkv, seq_dim=1, unpadded_dim_size=seq_len)
         | 
| 178 | 
            +
                q, k, v = qkv.split(pad_inner_dim // sp_size, dim=2)
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                pad_n = pad_inner_dim // d
         | 
| 181 | 
            +
                pad_split_n = pad_n // sp_size
         | 
| 182 | 
            +
                q = q.view(b, seq_len, pad_split_n, d)
         | 
| 183 | 
            +
                k = k.view(b, seq_len, pad_split_n, d)
         | 
| 184 | 
            +
                v = v.view(b, seq_len, pad_split_n, d)
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                q = rope_apply(q, grid_sizes, freqs)
         | 
| 187 | 
            +
                k = rope_apply(k, grid_sizes, freqs)
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                x = flash_attention(
         | 
| 190 | 
            +
                    q=half(q),
         | 
| 191 | 
            +
                    k=half(k),
         | 
| 192 | 
            +
                    v=half(v),
         | 
| 193 | 
            +
                    k_lens=seq_lens,
         | 
| 194 | 
            +
                    window_size=self.window_size
         | 
| 195 | 
            +
                )
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                # ulysses support
         | 
| 198 | 
            +
                x = x.flatten(2)
         | 
| 199 | 
            +
                x = gather_heads_scatter_seq(x, head_dim=2, seq_dim=1)
         | 
| 200 | 
            +
                if n % sp_size:
         | 
| 201 | 
            +
                    x = unpad_tensor(x, dim=2, unpad_dim_size=seq_len)
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                x = self.o(x)
         | 
| 204 | 
            +
                return x
         | 
| 205 | 
            +
             | 
| 206 | 
            +
             | 
| 207 | 
            +
            def ulysses_audio_cross_attn_forward(
         | 
| 208 | 
            +
                self,
         | 
| 209 | 
            +
                x,
         | 
| 210 | 
            +
                audio,
         | 
| 211 | 
            +
                seq_lens,
         | 
| 212 | 
            +
                grid_sizes,
         | 
| 213 | 
            +
                freqs,
         | 
| 214 | 
            +
                audio_seq_len,
         | 
| 215 | 
            +
                dtype=torch.bfloat16
         | 
| 216 | 
            +
            ):
         | 
| 217 | 
            +
                b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
         | 
| 218 | 
            +
                seq_len = seq_lens.max()
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                q = self.norm_q(self.q(x))
         | 
| 221 | 
            +
                k = self.norm_k(self.k(audio))
         | 
| 222 | 
            +
                v = self.v(audio)
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                # ulysses support
         | 
| 225 | 
            +
                sp_size = get_unified_parallel_world_size()
         | 
| 226 | 
            +
                if n % sp_size:
         | 
| 227 | 
            +
                    pad_size = sp_size - (n % sp_size)
         | 
| 228 | 
            +
                    pad_size = pad_size * d
         | 
| 229 | 
            +
                    pad_inner_dim = n * d + pad_size
         | 
| 230 | 
            +
                    q = pad_tensor(q, dim=2, padding_size=pad_size)
         | 
| 231 | 
            +
                    k = pad_tensor(k, dim=2, padding_size=pad_size)
         | 
| 232 | 
            +
                    v = pad_tensor(v, dim=2, padding_size=pad_size)
         | 
| 233 | 
            +
                else:
         | 
| 234 | 
            +
                    pad_inner_dim = n * d
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                qq = torch.cat([q, q], dim=2)
         | 
| 237 | 
            +
                kv = torch.cat([k, v], dim=2)
         | 
| 238 | 
            +
                qq = gather_seq_scatter_double_head(qq, seq_dim=1, unpadded_dim_size=seq_len)
         | 
| 239 | 
            +
                kv = gather_seq_scatter_double_head(kv, seq_dim=1, unpadded_dim_size=audio_seq_len)
         | 
| 240 | 
            +
                q, _ = qq.split(pad_inner_dim // sp_size, dim=2)
         | 
| 241 | 
            +
                k, v = kv.split(pad_inner_dim // sp_size, dim=2)
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                pad_n = pad_inner_dim // d
         | 
| 244 | 
            +
                pad_split_n = pad_n // sp_size
         | 
| 245 | 
            +
                q = q.view(b, seq_len, pad_split_n, d)
         | 
| 246 | 
            +
                k = k.view(b, audio_seq_len, pad_split_n, d)
         | 
| 247 | 
            +
                v = v.view(b, audio_seq_len, pad_split_n, d)
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                hlen_wlen = int(grid_sizes[0][1] * grid_sizes[0][2])
         | 
| 250 | 
            +
    assert hlen_wlen == 1560 or hlen_wlen == 3600  # 832x480 (30*52) or 1280x720 (45*80) latent grids
         | 
| 251 | 
            +
                q = q.reshape(-1, hlen_wlen, pad_split_n, d)
         | 
| 252 | 
            +
                k = k.reshape(-1, 16, pad_split_n, d)
         | 
| 253 | 
            +
                v = v.reshape(-1, 16, pad_split_n, d)
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                x = flash_attention(
         | 
| 256 | 
            +
                    q=q,
         | 
| 257 | 
            +
                    k=k,
         | 
| 258 | 
            +
                    v=v,
         | 
| 259 | 
            +
                    k_lens=None,
         | 
| 260 | 
            +
                )
         | 
| 261 | 
            +
                x = x.view(b, -1, pad_split_n, d)
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                # ulysses support
         | 
| 264 | 
            +
                x = x.flatten(2)
         | 
| 265 | 
            +
                x = gather_heads_scatter_seq(x, head_dim=2, seq_dim=1)
         | 
| 266 | 
            +
                if n % sp_size:
         | 
| 267 | 
            +
                    x = unpad_tensor(x, dim=2, unpad_dim_size=seq_len)
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                x = self.o(x)
         | 
| 270 | 
            +
                return x
         | 
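    Both Ulysses forwards above pad the flattened head dimension up to a multiple of the sequence-parallel world size before the all-to-all exchange, so every rank receives an equal shard of heads. A standalone sketch of that padding arithmetic, with illustrative sizes (it mirrors the pad_tensor(..., dim=2) branch, not the repo's helper itself):

    import torch
    import torch.nn.functional as F

    # Pad n*d (the flattened head dim) so the head count divides sp_size.
    def pad_heads(q, n, d, sp_size):
        if n % sp_size:
            pad_size = (sp_size - n % sp_size) * d
            return F.pad(q, (0, pad_size)), n * d + pad_size  # pad last dim
        return q, n * d

    q = torch.randn(1, 128, 12 * 64)                # n=12 heads, d=64
    q_pad, pad_inner_dim = pad_heads(q, n=12, d=64, sp_size=8)
    print(pad_inner_dim // 64, pad_inner_dim // 8)  # 16 padded heads, 128 dims per rank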
    	
        humo/models/distributed/fsdp.py
    ADDED
    
    | @@ -0,0 +1,42 @@ | |
| 1 | 
            +
            # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
         | 
| 2 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 3 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 4 | 
            +
            # You may obtain a copy of the License at
         | 
| 5 | 
            +
            # http://www.apache.org/licenses/LICENSE-2.0
         | 
| 6 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 7 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 8 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 9 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 10 | 
            +
            # limitations under the License.
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            from functools import partial
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            import torch
         | 
| 15 | 
            +
            from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
         | 
| 16 | 
            +
            from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
         | 
| 17 | 
            +
            from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            def shard_model(
         | 
| 21 | 
            +
                model,
         | 
| 22 | 
            +
                device_id,
         | 
| 23 | 
            +
                param_dtype=torch.bfloat16,
         | 
| 24 | 
            +
                reduce_dtype=torch.float32,
         | 
| 25 | 
            +
                buffer_dtype=torch.float32,
         | 
| 26 | 
            +
                process_group=None,
         | 
| 27 | 
            +
                sharding_strategy=ShardingStrategy.FULL_SHARD,
         | 
| 28 | 
            +
                sync_module_states=True,
         | 
| 29 | 
            +
            ):
         | 
| 30 | 
            +
                model = FSDP(
         | 
| 31 | 
            +
                    module=model,
         | 
| 32 | 
            +
                    process_group=process_group,
         | 
| 33 | 
            +
                    sharding_strategy=sharding_strategy,
         | 
| 34 | 
            +
                    auto_wrap_policy=partial(
         | 
| 35 | 
            +
                        lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
         | 
| 36 | 
            +
                    mixed_precision=MixedPrecision(
         | 
| 37 | 
            +
                        param_dtype=param_dtype,
         | 
| 38 | 
            +
                        reduce_dtype=reduce_dtype,
         | 
| 39 | 
            +
                        buffer_dtype=buffer_dtype),
         | 
| 40 | 
            +
                    device_id=device_id,
         | 
| 41 | 
            +
                    sync_module_states=sync_module_states)
         | 
| 42 | 
            +
                return model
         | 
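    shard_model wraps every element of model.blocks as its own FSDP unit (FULL_SHARD, bf16 compute, fp32 reductions and buffers). A hedged usage sketch; TinyDiT is a hypothetical stand-in exposing the .blocks ModuleList that the auto-wrap policy keys on, where the real caller passes the Wan DiT instead:

    import torch
    import torch.distributed as dist
    from torch import nn

    # Stand-in module with .blocks, which lambda_auto_wrap_policy matches.
    class TinyDiT(nn.Module):
        def __init__(self):
            super().__init__()
            self.blocks = nn.ModuleList(nn.Linear(64, 64) for _ in range(4))

    dist.init_process_group("nccl")  # assumes torchrun has set the env vars
    local_rank = dist.get_rank() % torch.cuda.device_count()
    sharded = shard_model(TinyDiT(), device_id=local_rank)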
    	
        humo/models/text/encoder.py
    ADDED
    
    | @@ -0,0 +1,173 @@ | |
import os
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
from omegaconf import DictConfig, OmegaConf
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    CLIPTextModel,
    CLIPTokenizerFast,
    T5EncoderModel,
    T5TokenizerFast,
)
from transformers.tokenization_utils_base import BatchEncoding

from common.fs import download_and_extract
from common.logger import get_logger

logger = get_logger(__name__)

MODEL_TYPES = {
    "clip": (CLIPTokenizerFast, CLIPTextModel),
    "t5": (T5TokenizerFast, T5EncoderModel),
    "llm14b": (AutoTokenizer, AutoModelForCausalLM),
}


@dataclass
class TextEncoderOutput:
    embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]]
    masks: Union[torch.BoolTensor, List[torch.BoolTensor]]
    pooled: Optional[Union[torch.FloatTensor, List[torch.FloatTensor]]]


class TextEncoder(nn.Module):
    def __init__(self, config: DictConfig):
        super().__init__()
        self.config = config
        self.tokenizers = []
        self.models = nn.ModuleList([])

        # Disable tokenizer parallelism since we already use distributed training.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        for model in config.models:
            tokenizer_cls, model_cls = MODEL_TYPES[model.type]
            path = download_and_extract(model.path)
            max_length = model.max_length

            if model.type == "llm14b":
                tokenizer = tokenizer_cls.from_pretrained(
                    path,
                    model_max_length=max_length,
                    use_fast=False,
                    trust_remote_code=True,
                    padding_side="right",
                    truncation_side="right",
                    add_eod_token=True,
                )
                tokenizer.add_special_tokens({"pad_token": "<|endoftext|>"})
                model = model_cls.from_pretrained(path, trust_remote_code=True, bf16=True)
            else:
                tokenizer = tokenizer_cls.from_pretrained(path, model_max_length=max_length)
                model = model_cls.from_pretrained(path, torch_dtype=torch.bfloat16)
            self.tokenizers.append(tokenizer)
            self.models.append(model)

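    # forward() runs `text` through every configured tokenizer/encoder pair,
    # then selects or merges the per-encoder embeddings, masks, and pooled
    # outputs according to the optional `output` section of the config.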
    def forward(self, text: Union[str, List[str]]) -> TextEncoderOutput:
        embeddings, masks, pooled = [], [], []

        for encoder_config, tokenizer, model in zip(
            self.config.models, self.tokenizers, self.models
        ):
            if encoder_config.type == "llm14b":
                use_mask = encoder_config.get("mask", True)
                tokens = tokenizer(
                    text,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                ).to(model.device)
                token_ids = tokens["input_ids"]
                attention_mask = tokens["attention_mask"]
                num_tokens = attention_mask.sum(dim=1)
                range_ids = torch.arange(len(token_ids), device=token_ids.device, dtype=torch.long)
                token_ids[range_ids, num_tokens.clamp(max=token_ids.size(1) - 1)] = (
                    tokenizer.pad_token_id
                )
                attention_mask[range_ids, num_tokens.clamp(max=token_ids.size(1) - 1)] = 1
                tokens = BatchEncoding({"input_ids": token_ids, "attention_mask": attention_mask})
                output = model.transformer(
                    input_ids=tokens.input_ids,
                    attention_mask=attention_mask if use_mask else None,
                    output_hidden_states=False,
                    use_cache=False,
                )
                emb = output.last_hidden_state  # batch_size, num_tokens, feat_dim
                # emb *= tokens.attention_mask.unsqueeze(-1)

                embeddings.append(emb)
                masks.append(
                    tokens.attention_mask.bool() if use_mask else tokens.attention_mask > -1
                )

            else:
                # Tokenizer
                tokens = tokenizer(
                    text=text,
                    truncation=True,
                    padding="max_length",
                    return_tensors="pt",
                )

                # Encoder
                use_mask = encoder_config.get("mask", True)
                input_ids = tokens.input_ids.to(model.device)
                attention_mask = tokens.attention_mask.to(model.device)
                output = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask if use_mask else None,
                    output_hidden_states=True,
                )

                # Save embeddings from the defined layer.
                layer = encoder_config.get("layer", "last")
                if layer == "last":
                    embeddings.append(output.last_hidden_state)
                elif layer == "penultimate":
                    embeddings.append(model.text_model.final_layer_norm(output.hidden_states[-2]))
                elif layer == "penultimate_nonorm":
                    embeddings.append(output.hidden_states[-2])
                else:
                    raise NotImplementedError(f"Unknown layer type: {layer}.")

                # Save masks
                masks.append(attention_mask.bool() if use_mask else attention_mask > -1)

                # Save pooled output if available.
                if hasattr(output, "pooler_output"):
                    pooled.append(output.pooler_output)

        output_config = self.config.get("output") or OmegaConf.create()
        embedding_output_type = output_config.get("embedding_and_mask", "undefined")
        pooled_output_type = output_config.get("pooled", "undefined")

        # Select or merge embeddings and mask if needed.
        if embedding_output_type == "undefined" and len(self.models) == 1:
            embeddings = embeddings[0]
            masks = masks[0]
        elif embedding_output_type == "channel_concat":
            embeddings = torch.cat(embeddings, dim=-1)
            masks = sum(masks).bool()
        elif embedding_output_type == "last":
            embeddings = embeddings[-1]
            masks = masks[-1]
        else:
            raise NotImplementedError(f"output.embedding_and_mask: {embedding_output_type}")

        # Select or merge pooled output if needed.
        if pooled_output_type == "undefined":
            pooled = None
        elif pooled_output_type == "channel_concat":
            pooled = torch.cat(pooled, dim=-1)
        elif pooled_output_type == "first":
            pooled = pooled[0]
        elif pooled_output_type == "last":
            pooled = pooled[-1]
        else:
            raise NotImplementedError(f"output.pooled: {pooled_output_type}")

        # Return final results.
        return TextEncoderOutput(embeddings, masks, pooled)
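
A minimal usage sketch for the encoder above, assuming a single-model setup. The keys (`models`, `type`, `path`, `max_length`, `layer`) mirror what the constructor and `forward` read; the checkpoint path is a placeholder, not a value from this repo:

from omegaconf import OmegaConf

# Hypothetical single-encoder config; with one model and no `output`
# section, forward() returns that model's tensors directly (pooled is None).
cfg = OmegaConf.create({
    "models": [
        {"type": "t5", "path": "/path/to/t5", "max_length": 512, "layer": "last"},
    ],
})
encoder = TextEncoder(cfg).eval()
out = encoder(["a prompt", "another prompt"])
# out.embeddings: (2, 512, hidden_dim) bfloat16; out.masks: (2, 512) bool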
    	
        humo/models/utils/fm_solvers.py
    ADDED
    
@@ -0,0 +1,857 @@
# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
# Convert dpm solver for flow matching
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.

import inspect
import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
                                                   SchedulerMixin,
                                                   SchedulerOutput)
from diffusers.utils import deprecate, is_scipy_available
from diffusers.utils.torch_utils import randn_tensor

if is_scipy_available():
    pass


def get_sampling_sigmas(sampling_steps, shift):
    sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
    sigma = (shift * sigma / (1 + (shift - 1) * sigma))

    return sigma
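
# Worked example for get_sampling_sigmas: sampling_steps=4 gives unshifted
# sigmas [1.0, 0.75, 0.5, 0.25]; with shift=3, shift*s / (1 + (shift-1)*s)
# maps them to [1.0, 0.9, 0.75, 0.5], concentrating steps at high noise.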


def retrieve_timesteps(
    scheduler,
    num_inference_steps=None,
    device=None,
    timesteps=None,
    sigmas=None,
    **kwargs,
):
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
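
# Usage sketch (illustrative; `sched` stands for any scheduler whose
# `set_timesteps` accepts the given keyword):
#   timesteps, n = retrieve_timesteps(sched, num_inference_steps=50, device="cuda")
#   timesteps, n = retrieve_timesteps(sched, sigmas=get_sampling_sigmas(50, shift=5.0), device="cuda")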


class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
    """
    `FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.
    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model. This determines the resolution of the diffusion process.
        solver_order (`int`, defaults to 2):
            The DPMSolver order which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided
            sampling, and `solver_order=3` for unconditional sampling. This affects the number of model outputs stored
            and used in multistep updates.
        prediction_type (`str`, defaults to "flow_prediction"):
            Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
            the flow of the diffusion process.
        shift (`float`, *optional*, defaults to 1.0):
            A factor used to adjust the sigmas in the noise schedule. It modifies the step sizes during the sampling
            process.
        use_dynamic_shifting (`bool`, defaults to `False`):
            Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is
            applied on the fly.
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This method adjusts the predicted sample to prevent
            saturation and improve photorealism.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
            `algorithm_type="dpmsolver++"`.
        algorithm_type (`str`, defaults to `dpmsolver++`):
            Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
            `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
            paper, and the `dpmsolver++` type implements the algorithms in the
            [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
            `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
        solver_type (`str`, defaults to `midpoint`):
            Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
            sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
        lower_order_final (`bool`, defaults to `True`):
            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
            stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
        euler_at_final (`bool`, defaults to `False`):
            Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
            richness. This can stabilize the sampling of the SDE variant of DPMSolver for a small number of inference
            steps, but sometimes may result in blurring.
        final_sigmas_type (`str`, *optional*, defaults to "zero"):
            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
        lambda_min_clipped (`float`, defaults to `-inf`):
            Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
            cosine (`squaredcos_cap_v2`) noise schedule.
        variance_type (`str`, *optional*):
            Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
            contains the predicted Gaussian variance.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        solver_order: int = 2,
        prediction_type: str = "flow_prediction",
        shift: Optional[float] = 1.0,
        use_dynamic_shifting=False,
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "midpoint",
        lower_order_final: bool = True,
        euler_at_final: bool = False,
        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
        lambda_min_clipped: float = -float("inf"),
        variance_type: Optional[str] = None,
        invert_sigmas: bool = False,
    ):
        if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
            deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
            deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0",
                      deprecation_message)

        # settings for DPM-Solver
        if algorithm_type not in [
                "dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"
        ]:
            if algorithm_type == "deis":
                self.register_to_config(algorithm_type="dpmsolver++")
            else:
                raise NotImplementedError(
                    f"{algorithm_type} is not implemented for {self.__class__}")

        if solver_type not in ["midpoint", "heun"]:
            if solver_type in ["logrho", "bh1", "bh2"]:
                self.register_to_config(solver_type="midpoint")
            else:
                raise NotImplementedError(
                    f"{solver_type} is not implemented for {self.__class__}")

        if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
            raise ValueError(
                f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
            )

        # setable values
        self.num_inference_steps = None
        alphas = np.linspace(1, 1 / num_train_timesteps,
                             num_train_timesteps)[::-1].copy()
        sigmas = 1.0 - alphas
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)

        if not use_dynamic_shifting:
            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # pyright: ignore

        self.sigmas = sigmas
        self.timesteps = sigmas * num_train_timesteps

        self.model_outputs = [None] * solver_order
        self.lower_order_nums = 0
        self._step_index = None
        self._begin_index = None

        # self.sigmas = self.sigmas.to(
        #     "cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It increases by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
    def set_timesteps(
        self,
        num_inference_steps: Union[int, None] = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[Union[float, None]] = None,
        shift: Optional[Union[float, None]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
        Args:
            num_inference_steps (`int`):
                The number of discretization steps used for the sampling schedule.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        """

        if self.config.use_dynamic_shifting and mu is None:
            raise ValueError(
                "You have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`"
            )

        if sigmas is None:
            sigmas = np.linspace(self.sigma_max, self.sigma_min,
                                 num_inference_steps + 1).copy()[:-1]  # pyright: ignore

        if self.config.use_dynamic_shifting:
            sigmas = self.time_shift(mu, 1.0, sigmas)  # pyright: ignore
        else:
            if shift is None:
                shift = self.config.shift
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # pyright: ignore

        if self.config.final_sigmas_type == "sigma_min":
            sigma_last = ((1 - self.alphas_cumprod[0]) /
                          self.alphas_cumprod[0])**0.5
        elif self.config.final_sigmas_type == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
            )

        timesteps = sigmas * self.config.num_train_timesteps
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)  # pyright: ignore

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(
            device=device, dtype=torch.int64)

        self.num_inference_steps = len(timesteps)

        self.model_outputs = [
            None,
        ] * self.config.solver_order
        self.lower_order_nums = 0

        self._step_index = None
        self._begin_index = None
        # self.sigmas = self.sigmas.to(
        #     "cpu")  # to avoid too much CPU/GPU communication

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."
        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(
            abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(
            sample, -s, s
        ) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample
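
    # Illustrative numbers: if the 99.5th-percentile absolute value of a
    # predicted x0 is s = 2.5, _threshold_sample clamps the prediction to
    # [-2.5, 2.5] and divides by 2.5; if s <= 1, the clamp to min=1 reduces
    # this to plain clipping to [-1, 1].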

    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma):
        return sigma * self.config.num_train_timesteps

    def _sigma_to_alpha_sigma_t(self, sigma):
        return 1 - sigma, sigma

    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
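
    # Illustrative check: with sigma=1.0, time_shift reduces to the static
    # shift used in set_timesteps with shift = exp(mu); e.g. mu=math.log(3),
    # t=0.5 gives 3 / (3 + (1/0.5 - 1)) = 0.75 = 3*0.5 / (1 + 2*0.5).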
| 339 | 
            +
             | 
| 340 | 
            +
                # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
         | 
| 341 | 
            +
                def convert_model_output(
         | 
| 342 | 
            +
                    self,
         | 
| 343 | 
            +
                    model_output: torch.Tensor,
         | 
| 344 | 
            +
                    *args,
         | 
| 345 | 
            +
                    sample: torch.Tensor = None,
         | 
| 346 | 
            +
                    **kwargs,
         | 
| 347 | 
            +
                ) -> torch.Tensor:
         | 
| 348 | 
            +
                    """
         | 
| 349 | 
            +
                    Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
         | 
| 350 | 
            +
                    designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
         | 
| 351 | 
            +
                    integral of the data prediction model.
         | 
| 352 | 
            +
                    <Tip>
         | 
| 353 | 
            +
                    The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
         | 
| 354 | 
            +
                    prediction and data prediction models.
         | 
| 355 | 
            +
                    </Tip>
         | 
| 356 | 
            +
                    Args:
         | 
| 357 | 
            +
                        model_output (`torch.Tensor`):
         | 
| 358 | 
            +
                            The direct output from the learned diffusion model.
         | 
| 359 | 
            +
                        sample (`torch.Tensor`):
         | 
| 360 | 
            +
                            A current instance of a sample created by the diffusion process.
         | 
| 361 | 
            +
                    Returns:
         | 
| 362 | 
            +
                        `torch.Tensor`:
         | 
| 363 | 
            +
                            The converted model output.
         | 
| 364 | 
            +
                    """
         | 
| 365 | 
            +
                    timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
         | 
| 366 | 
            +
                    if sample is None:
         | 
| 367 | 
            +
                        if len(args) > 1:
         | 
| 368 | 
            +
                            sample = args[1]
            else:
                raise ValueError(
                    "missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # DPM-Solver++ needs to solve an integral of the data prediction model.
        if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
            if self.config.prediction_type == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                x0_pred = sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be "
                    "`flow_prediction` for the FlowDPMSolverMultistepScheduler."
                )

            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred

        # DPM-Solver needs to solve an integral of the noise prediction model.
        elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
            if self.config.prediction_type == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                epsilon = sample - (1 - sigma_t) * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be "
                    "`flow_prediction` for the FlowDPMSolverMultistepScheduler."
                )

            if self.config.thresholding:
                sigma_t = self.sigmas[self.step_index]
                x0_pred = sample - sigma_t * model_output
                x0_pred = self._threshold_sample(x0_pred)
                epsilon = model_output + x0_pred

            return epsilon

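# --- Illustrative sketch (not part of the original file) ---------------------
# Why `x0_pred = sample - sigma_t * model_output` recovers the clean sample:
# under the rectified-flow parameterization assumed here, x_t = (1 - sigma) * x0
# + sigma * noise and the network predicts the velocity v = noise - x0, so
# x0 = x_t - sigma * v. A minimal numerical check of that identity:
import torch

sigma = 0.7
x0 = torch.randn(4)
noise = torch.randn(4)
x_t = (1 - sigma) * x0 + sigma * noise
v = noise - x0                     # what a flow-prediction model is trained to output
assert torch.allclose(x_t - sigma * v, x0, atol=1e-6)
# ------------------------------------------------------------------------------
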
    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update
    def dpm_solver_first_order_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        noise: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the first-order DPMSolver (equivalent to DDIM).

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
            "prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError(
                    "missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[
            self.step_index]  # pyright: ignore
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s = torch.log(alpha_s) - torch.log(sigma_s)

        h = lambda_t - lambda_s
        if self.config.algorithm_type == "dpmsolver++":
            x_t = (sigma_t / sigma_s) * sample - (
                alpha_t * (torch.exp(-h) - 1.0)) * model_output
        elif self.config.algorithm_type == "dpmsolver":
            x_t = (alpha_t / alpha_s) * sample - (
                sigma_t * (torch.exp(h) - 1.0)) * model_output
        elif self.config.algorithm_type == "sde-dpmsolver++":
            assert noise is not None
            x_t = ((sigma_t / sigma_s * torch.exp(-h)) * sample +
                   (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output +
                   sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
        elif self.config.algorithm_type == "sde-dpmsolver":
            assert noise is not None
            x_t = ((alpha_t / alpha_s) * sample - 2.0 *
                   (sigma_t * (torch.exp(h) - 1.0)) * model_output +
                   sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
        return x_t  # pyright: ignore

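# --- Illustrative sketch (not part of the original file) ---------------------
# For "dpmsolver++" the first-order update above is algebraically a DDIM step
# on the data prediction: x_t = (sigma_t / sigma_s) * x_s
# + (alpha_t - alpha_s * sigma_t / sigma_s) * x0_pred. A minimal scalar check,
# assuming alpha = 1 - sigma as in this flow-matching scheduler:
import torch

sigma_s, sigma_t = torch.tensor(0.8), torch.tensor(0.6)
alpha_s, alpha_t = 1 - sigma_s, 1 - sigma_t
h = (torch.log(alpha_t) - torch.log(sigma_t)) - (torch.log(alpha_s) - torch.log(sigma_s))
x_s, x0_pred = torch.tensor(0.3), torch.tensor(1.1)
update = (sigma_t / sigma_s) * x_s - alpha_t * (torch.exp(-h) - 1.0) * x0_pred
ddim = (sigma_t / sigma_s) * x_s + (alpha_t - alpha_s * sigma_t / sigma_s) * x0_pred
assert torch.allclose(update, ddim, atol=1e-6)
# ------------------------------------------------------------------------------
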
    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
    def multistep_dpm_solver_second_order_update(
        self,
        model_output_list: List[torch.Tensor],
        *args,
        sample: torch.Tensor = None,
        noise: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the second-order multistep DPMSolver.

        Args:
            model_output_list (`List[torch.Tensor]`):
                The direct outputs from the learned diffusion model at the current and earlier timesteps.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        timestep_list = args[0] if len(args) > 0 else kwargs.pop(
            "timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
            "prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError(
                    "missing `sample` as a required keyword argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma_t, sigma_s0, sigma_s1 = (
            self.sigmas[self.step_index + 1],  # pyright: ignore
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],  # pyright: ignore
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)

        m0, m1 = model_output_list[-1], model_output_list[-2]

        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
        r0 = h_0 / h
        D0, D1 = m0, (1.0 / r0) * (m0 - m1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2211.01095 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = ((sigma_t / sigma_s0) * sample -
                       (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 *
                       (alpha_t * (torch.exp(-h) - 1.0)) * D1)
            elif self.config.solver_type == "heun":
                x_t = ((sigma_t / sigma_s0) * sample -
                       (alpha_t * (torch.exp(-h) - 1.0)) * D0 +
                       (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1)
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = ((alpha_t / alpha_s0) * sample -
                       (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 *
                       (sigma_t * (torch.exp(h) - 1.0)) * D1)
            elif self.config.solver_type == "heun":
                x_t = ((alpha_t / alpha_s0) * sample -
                       (sigma_t * (torch.exp(h) - 1.0)) * D0 -
                       (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1)
        elif self.config.algorithm_type == "sde-dpmsolver++":
            assert noise is not None
            if self.config.solver_type == "midpoint":
                x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample +
                       (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 *
                       (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 +
                       sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
            elif self.config.solver_type == "heun":
                x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample +
                       (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 +
                       (alpha_t * ((1.0 - torch.exp(-2.0 * h)) /
                                   (-2.0 * h) + 1.0)) * D1 +
                       sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
        elif self.config.algorithm_type == "sde-dpmsolver":
            assert noise is not None
            if self.config.solver_type == "midpoint":
                x_t = ((alpha_t / alpha_s0) * sample - 2.0 *
                       (sigma_t * (torch.exp(h) - 1.0)) * D0 -
                       (sigma_t * (torch.exp(h) - 1.0)) * D1 +
                       sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
            elif self.config.solver_type == "heun":
                x_t = ((alpha_t / alpha_s0) * sample - 2.0 *
                       (sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 *
                       (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 +
                       sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
        return x_t  # pyright: ignore

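# --- Illustrative sketch (not part of the original file) ---------------------
# Intuition for D0/D1 above: the solver treats the model output as a function
# m(lambda) of the log-SNR and builds a finite-difference estimate of its slope
# from the two most recent outputs, so D1 / h ~ dm/dlambda. Assuming
# m(lambda) = 2 * lambda for illustration, the estimate is exact:
lambda_s0, lambda_s1 = 0.5, 0.2          # current and previous log-SNR values
lambda_t = 0.9                           # target log-SNR
m0, m1 = 2 * lambda_s0, 2 * lambda_s1    # model outputs at s0 and s1
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
r0 = h_0 / h
D1 = (1.0 / r0) * (m0 - m1)              # = (m0 - m1) / h_0 * h
assert abs(D1 / h - 2.0) < 1e-12         # recovers the true slope dm/dlambda = 2
# ------------------------------------------------------------------------------
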
    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
    def multistep_dpm_solver_third_order_update(
        self,
        model_output_list: List[torch.Tensor],
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the third-order multistep DPMSolver.

        Args:
            model_output_list (`List[torch.Tensor]`):
                The direct outputs from the learned diffusion model at the current and earlier timesteps.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """

        timestep_list = args[0] if len(args) > 0 else kwargs.pop(
            "timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
            "prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError(
                    "missing `sample` as a required keyword argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
            self.sigmas[self.step_index + 1],  # pyright: ignore
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],  # pyright: ignore
            self.sigmas[self.step_index - 2],  # pyright: ignore
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
        alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
        lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)

        m0, m1, m2 = (model_output_list[-1], model_output_list[-2],
                      model_output_list[-3])

        h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
        r0, r1 = h_0 / h, h_1 / h
        D0 = m0
        D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
        D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
        D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            x_t = ((sigma_t / sigma_s0) * sample -
                   (alpha_t * (torch.exp(-h) - 1.0)) * D0 +
                   (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 -
                   (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2)
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            x_t = ((alpha_t / alpha_s0) * sample -
                   (sigma_t * (torch.exp(h) - 1.0)) * D0 -
                   (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 -
                   (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2)
        return x_t  # pyright: ignore

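# --- Illustrative sketch (not part of the original file) ---------------------
# D1 and D2 above are divided-difference estimates of the first and second
# derivatives of the model output with respect to log-SNR, scaled by powers of
# the step h. Assuming m(lambda) = lambda**2, D2 equals h**2 * m''/2 exactly:
lambda_t, lambda_s0, lambda_s1, lambda_s2 = 1.0, 0.7, 0.5, 0.2
m0, m1, m2 = lambda_s0**2, lambda_s1**2, lambda_s2**2
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
r0, r1 = h_0 / h, h_1 / h
D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
assert abs(D2 - h**2) < 1e-12            # m'' = 2, so h**2 * m''/2 == h**2
# ------------------------------------------------------------------------------
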
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

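# --- Illustrative sketch (not part of the original file) ---------------------
# Why `pos = 1` when a timestep occurs twice: with a duplicated entry in the
# schedule (as can happen when starting mid-schedule), taking the second match
# guarantees the solver does not accidentally skip a sigma. Minimal check:
import torch

schedule_timesteps = torch.tensor([900, 700, 700, 500, 300])
indices = (schedule_timesteps == 700).nonzero()
pos = 1 if len(indices) > 1 else 0
assert indices[pos].item() == 2          # the second occurrence is selected
# ------------------------------------------------------------------------------
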
    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.
        """

        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    # Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.Tensor,
        generator=None,
        variance_noise: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the multistep DPMSolver.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`torch.Tensor`):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`LEdits++`].
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Improve numerical stability for small numbers of steps
        lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
            self.config.euler_at_final or
            (self.config.lower_order_final and len(self.timesteps) < 15) or
            self.config.final_sigmas_type == "zero")
        lower_order_second = ((self.step_index == len(self.timesteps) - 2) and
                              self.config.lower_order_final and
                              len(self.timesteps) < 15)

        model_output = self.convert_model_output(model_output, sample=sample)
        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
        self.model_outputs[-1] = model_output

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)
        if (self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]
                and variance_noise is None):
            noise = randn_tensor(
                model_output.shape,
                generator=generator,
                device=model_output.device,
                dtype=torch.float32)
        elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
            noise = variance_noise.to(
                device=model_output.device,
                dtype=torch.float32)  # pyright: ignore
        else:
            noise = None

        if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
            prev_sample = self.dpm_solver_first_order_update(
                model_output, sample=sample, noise=noise)
        elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
            prev_sample = self.multistep_dpm_solver_second_order_update(
                self.model_outputs, sample=sample, noise=noise)
        else:
            prev_sample = self.multistep_dpm_solver_third_order_update(
                self.model_outputs, sample=sample)

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        # Cast sample back to expected dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1  # pyright: ignore

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

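# --- Illustrative sketch (not part of the original file) ---------------------
# A minimal sampling loop around `step`, with a hypothetical `velocity_model`
# standing in for the real flow-prediction network, and assuming the
# `set_timesteps` interface defined earlier in this file:
import torch

def velocity_model(x, t):
    return torch.zeros_like(x)  # placeholder for a trained model's forward pass

scheduler = FlowDPMSolverMultistepScheduler(shift=5.0)
scheduler.set_timesteps(num_inference_steps=30)
x = torch.randn(1, 16, 32, 32)  # initial noise
for t in scheduler.timesteps:
    v = velocity_model(x, t)
    x = scheduler.step(v, t, x).prev_sample  # one multistep DPM-Solver update
# ------------------------------------------------------------------------------
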
    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input
    def scale_model_input(self, sample: torch.Tensor, *args,
                          **kwargs) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        return sample

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(
            device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(
                timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(
                original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(
                original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # begin_index is None when the scheduler is used for training or the pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [
                self.index_for_timestep(t, schedule_timesteps)
                for t in timesteps
            ]
        elif self.step_index is not None:
            # add_noise is called after the first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add_noise is called before the first denoising step to create the initial latent (img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
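
# --- Illustrative sketch (not part of the original file) ---------------------
# `add_noise` linearly interpolates between data and noise, assuming the
# flow-matching convention alpha = 1 - sigma used by this scheduler's
# `_sigma_to_alpha_sigma_t`: noisy = (1 - sigma) * original + sigma * noise.
# Quick check at sigma = 0.25:
import torch

sigma = torch.tensor(0.25)
alpha = 1 - sigma
original, noise = torch.zeros(3), torch.ones(3)
noisy = alpha * original + sigma * noise
assert torch.allclose(noisy, torch.full((3,), 0.25))
# ------------------------------------------------------------------------------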
         | 
    	
        humo/models/utils/fm_solvers_unipc.py
    ADDED
    
@@ -0,0 +1,800 @@
# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
# UniPC converted for flow matching
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.

import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
                                                   SchedulerMixin,
                                                   SchedulerOutput)
from diffusers.utils import deprecate, is_scipy_available

if is_scipy_available():
    import scipy.stats


class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
    """
    `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        solver_order (`int`, default `2`):
            The UniPC order, which can be any positive integer. The effective order of accuracy is `solver_order + 1`
            due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
            unconditional sampling.
        prediction_type (`str`, defaults to "flow_prediction"):
            Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
            the flow of the diffusion process.
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
        predict_x0 (`bool`, defaults to `True`):
            Whether to use the updating algorithm on the predicted x0.
        solver_type (`str`, default `bh2`):
            Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
            otherwise.
        lower_order_final (`bool`, default `True`):
            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
            stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
        disable_corrector (`list`, default `[]`):
            Decides at which steps to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
            and `epsilon_theta(x_t^c, c)`, which can influence convergence for a large guidance scale. The corrector is
            usually disabled during the first few steps.
        solver_p (`SchedulerMixin`, default `None`):
            Any other scheduler that, if specified, turns the algorithm into `solver_p + UniC`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        final_sigmas_type (`str`, defaults to `"zero"`):
            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
    """

| 73 | 
            +
                _compatibles = [e.name for e in KarrasDiffusionSchedulers]
         | 
| 74 | 
            +
                order = 1
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                @register_to_config
         | 
| 77 | 
            +
                def __init__(
         | 
| 78 | 
            +
                        self,
         | 
| 79 | 
            +
                        num_train_timesteps: int = 1000,
         | 
| 80 | 
            +
                        solver_order: int = 2,
         | 
| 81 | 
            +
                        prediction_type: str = "flow_prediction",
         | 
| 82 | 
            +
                        shift: Optional[float] = 1.0,
         | 
| 83 | 
            +
                        use_dynamic_shifting=False,
         | 
| 84 | 
            +
                        thresholding: bool = False,
         | 
| 85 | 
            +
                        dynamic_thresholding_ratio: float = 0.995,
         | 
| 86 | 
            +
                        sample_max_value: float = 1.0,
         | 
| 87 | 
            +
                        predict_x0: bool = True,
         | 
| 88 | 
            +
                        solver_type: str = "bh2",
         | 
| 89 | 
            +
                        lower_order_final: bool = True,
         | 
| 90 | 
            +
                        disable_corrector: List[int] = [],
         | 
| 91 | 
            +
                        solver_p: SchedulerMixin = None,
         | 
| 92 | 
            +
                        timestep_spacing: str = "linspace",
         | 
| 93 | 
            +
                        steps_offset: int = 0,
         | 
| 94 | 
            +
                        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         | 
| 95 | 
            +
                ):
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                    if solver_type not in ["bh1", "bh2"]:
         | 
| 98 | 
            +
                        if solver_type in ["midpoint", "heun", "logrho"]:
         | 
| 99 | 
            +
                            self.register_to_config(solver_type="bh2")
         | 
| 100 | 
            +
                        else:
         | 
| 101 | 
            +
                            raise NotImplementedError(
         | 
| 102 | 
            +
                                f"{solver_type} is not implemented for {self.__class__}")
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    self.predict_x0 = predict_x0
         | 
| 105 | 
            +
                    # setable values
         | 
| 106 | 
            +
                    self.num_inference_steps = None
         | 
| 107 | 
            +
                    alphas = np.linspace(1, 1 / num_train_timesteps,
         | 
| 108 | 
            +
                                         num_train_timesteps)[::-1].copy()
         | 
| 109 | 
            +
                    sigmas = 1.0 - alphas
         | 
| 110 | 
            +
                    sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    if not use_dynamic_shifting:
         | 
| 113 | 
            +
                        # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
         | 
| 114 | 
            +
                        sigmas = shift * sigmas / (1 +
         | 
| 115 | 
            +
                                                   (shift - 1) * sigmas)  # pyright: ignore
         | 
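            # Illustrative arithmetic (not from the original source): with
            # shift = 5.0, a mid-schedule sigma of 0.5 maps to
            # 5 * 0.5 / (1 + 4 * 0.5) ≈ 0.833, so larger shifts keep the
            # schedule at higher noise levels for longer.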

        self.sigmas = sigmas
        self.timesteps = sigmas * num_train_timesteps

        self.model_outputs = [None] * solver_order
        self.timestep_list = [None] * solver_order
        self.lower_order_nums = 0
        self.disable_corrector = disable_corrector
        self.solver_p = solver_p
        self.last_sample = None
        self._step_index = None
        self._begin_index = None

        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()
            +
             | 
| 134 | 
            +
                @property
         | 
| 135 | 
            +
                def step_index(self):
         | 
| 136 | 
            +
                    """
         | 
| 137 | 
            +
                    The index counter for current timestep. It will increase 1 after each scheduler step.
         | 
| 138 | 
            +
                    """
         | 
| 139 | 
            +
                    return self._step_index
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                @property
         | 
| 142 | 
            +
                def begin_index(self):
         | 
| 143 | 
            +
                    """
         | 
| 144 | 
            +
                    The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
         | 
| 145 | 
            +
                    """
         | 
| 146 | 
            +
                    return self._begin_index
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
         | 
| 149 | 
            +
                def set_begin_index(self, begin_index: int = 0):
         | 
| 150 | 
            +
                    """
         | 
| 151 | 
            +
                    Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                    Args:
         | 
| 154 | 
            +
                        begin_index (`int`):
         | 
| 155 | 
            +
                            The begin index for the scheduler.
         | 
| 156 | 
            +
                    """
         | 
| 157 | 
            +
                    self._begin_index = begin_index
         | 

    # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
    def set_timesteps(
        self,
        num_inference_steps: Union[int, None] = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[Union[float, None]] = None,
        shift: Optional[Union[float, None]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        """
        if self.config.use_dynamic_shifting and mu is None:
            raise ValueError(
                "you have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`"
            )

        if sigmas is None:
            sigmas = np.linspace(self.sigma_max, self.sigma_min,
                                 num_inference_steps + 1).copy()[:-1]  # pyright: ignore

        if self.config.use_dynamic_shifting:
            sigmas = self.time_shift(mu, 1.0, sigmas)  # pyright: ignore
        else:
            if shift is None:
                shift = self.config.shift
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # pyright: ignore

        if self.config.final_sigmas_type == "sigma_min":
            sigma_last = ((1 - self.alphas_cumprod[0]) /
                          self.alphas_cumprod[0])**0.5
        elif self.config.final_sigmas_type == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of 'zero' or 'sigma_min', but got {self.config.final_sigmas_type}"
            )

        timesteps = sigmas * self.config.num_train_timesteps
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)  # pyright: ignore

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(
            device=device, dtype=torch.int64)

        self.num_inference_steps = len(timesteps)

        self.model_outputs = [
            None,
        ] * self.config.solver_order
        self.lower_order_nums = 0
        self.last_sample = None
        if self.solver_p:
            self.solver_p.set_timesteps(self.num_inference_steps, device=device)

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
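
    # Usage sketch (illustrative, not from the original source; assumes `sched`
    # is an instance of this scheduler configured with the defaults above):
    #
    #     sched.set_timesteps(num_inference_steps=50, device="cuda", shift=5.0)
    #     assert len(sched.timesteps) == 50 and sched.sigmas.shape[0] == 51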

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            # upcast for the quantile calculation; clamp is not implemented for CPU half precision
            sample = sample.float()

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(
            abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample
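
    # Illustrative note (not from the original source): `s` is the 99.5th
    # percentile of |x0_pred| per sample with the default ratio of 0.995.
    # With the default sample_max_value = 1.0, `s` is clamped to exactly 1, so
    # this degenerates to plain clipping to [-1, 1]; a larger sample_max_value
    # lets the percentile widen the kept range before rescaling.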

    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma):
        return sigma * self.config.num_train_timesteps

    def _sigma_to_alpha_sigma_t(self, sigma):
        return 1 - sigma, sigma

    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
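
    # Worked example (illustrative, not from the original source): with
    # mu = ln(3) and sigma = 1.0, a point t = 0.5 maps to
    # 3 / (3 + (1/0.5 - 1)**1) = 3 / 4 = 0.75, i.e. dynamic shifting pushes
    # mid-schedule points toward the noisy end as mu grows.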

    def convert_model_output(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Convert the model output to the corresponding type the UniPC algorithm needs.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The converted model output.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError(
                    "missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma = self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)

        if self.predict_x0:
            if self.config.prediction_type == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                x0_pred = sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be "
                    "`flow_prediction` for this scheduler."
                )

            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred
        else:
            if self.config.prediction_type == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                epsilon = sample - (1 - sigma_t) * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be "
                    "`flow_prediction` for this scheduler."
                )

            if self.config.thresholding:
                sigma_t = self.sigmas[self.step_index]
                x0_pred = sample - sigma_t * model_output
                x0_pred = self._threshold_sample(x0_pred)
                epsilon = model_output + x0_pred

            return epsilon
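
    # Derivation note (illustrative, not from the original source): under the
    # rectified-flow convention x_t = (1 - sigma) * x0 + sigma * noise, with
    # the model predicting the velocity v = noise - x0, the x0 branch follows
    # from x_t - sigma * v = (1 - sigma) * x0 + sigma * x0 = x0.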

    def multistep_uni_p_bh_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        order: int = None,  # pyright: ignore
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model at the current timestep.
            prev_timestep (`int`):
                The previous discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            order (`int`):
                The order of UniP at this timestep (corresponds to the *p* in UniPC-p).

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        prev_timestep = args[0] if len(args) > 0 else kwargs.pop(
            "prev_timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError(
                    "missing `sample` as a required keyword argument")
        if order is None:
            if len(args) > 2:
                order = args[2]
            else:
                raise ValueError(
                    "missing `order` as a required keyword argument")
        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )
        model_output_list = self.model_outputs

        s0 = self.timestep_list[-1]
        m0 = model_output_list[-1]
        x = sample

        if self.solver_p:
            x_t = self.solver_p.step(model_output, s0, x).prev_sample
            return x_t

        sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[
            self.step_index]  # pyright: ignore
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0
        device = sample.device

        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - i  # pyright: ignore
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)  # pyright: ignore

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1],
                                            b[:-1]).to(device).to(x.dtype)
        else:
            D1s = None

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s is not None:
                pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)  # pyright: ignore
            else:
                pred_res = 0
            x_t = x_t_ - alpha_t * B_h * pred_res
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s is not None:
                pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)  # pyright: ignore
            else:
                pred_res = 0
            x_t = x_t_ - sigma_t * B_h * pred_res

        x_t = x_t.to(x.dtype)
        return x_t
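
    # Illustrative note (not from the original source): with order = 1 the
    # difference list D1s is empty and, in predict_x0 mode, the update reduces
    # to x_t = alpha_t * x0_pred + (sigma_t / sigma_s0) * (x - alpha_s0 * x0_pred),
    # i.e. a plain first-order (DDIM-style) step; higher orders add the einsum
    # correction term built from previous model outputs.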

    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.Tensor,
        *args,
        last_sample: torch.Tensor = None,
        this_sample: torch.Tensor = None,
        order: int = None,  # pyright: ignore
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the UniC (B(h) version).

        Args:
            this_model_output (`torch.Tensor`):
                The model outputs at `x_t`.
            this_timestep (`int`):
                The current timestep `t`.
            last_sample (`torch.Tensor`):
                The generated sample before the last predictor `x_{t-1}`.
            this_sample (`torch.Tensor`):
                The generated sample after the last predictor `x_{t}`.
            order (`int`):
                The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.

        Returns:
            `torch.Tensor`:
                The corrected sample tensor at the current timestep.
        """
        this_timestep = args[0] if len(args) > 0 else kwargs.pop(
            "this_timestep", None)
        if last_sample is None:
            if len(args) > 1:
                last_sample = args[1]
            else:
                raise ValueError(
                    "missing `last_sample` as a required keyword argument")
        if this_sample is None:
            if len(args) > 2:
                this_sample = args[2]
            else:
                raise ValueError(
                    "missing `this_sample` as a required keyword argument")
        if order is None:
            if len(args) > 3:
                order = args[3]
            else:
                raise ValueError(
                    "missing `order` as a required keyword argument")
        if this_timestep is not None:
            deprecate(
                "this_timestep",
                "1.0.0",
                "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        model_output_list = self.model_outputs

        m0 = model_output_list[-1]
        x = last_sample
        x_t = this_sample
        model_t = this_model_output

        sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[
            self.step_index - 1]  # pyright: ignore
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0
        device = this_sample.device

        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - (i + 1)  # pyright: ignore
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)  # pyright: ignore

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)
        else:
            D1s = None

        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s is not None:
                corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = model_t - m0
            x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s is not None:
                corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = model_t - m0
            x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        x_t = x_t.to(x.dtype)
        return x_t
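
    # Illustrative note (not from the original source): unlike the predictor,
    # the corrector re-derives x_t from the previous sample x_{t-1} and folds
    # in the fresh model output at x_t via the D1_t term, which is what lifts
    # the effective order of accuracy to order + 1.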

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1).
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.
        """
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index
                def step(self,
         | 
| 656 | 
            +
                         model_output: torch.Tensor,
         | 
| 657 | 
            +
                         timestep: Union[int, torch.Tensor],
         | 
| 658 | 
            +
                         sample: torch.Tensor,
         | 
| 659 | 
            +
                         return_dict: bool = True,
         | 
| 660 | 
            +
                         generator=None) -> Union[SchedulerOutput, Tuple]:
         | 
| 661 | 
            +
                    """
         | 
| 662 | 
            +
                    Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
         | 
| 663 | 
            +
                    the multistep UniPC.
         | 
| 664 | 
            +
             | 
| 665 | 
            +
                    Args:
         | 
| 666 | 
            +
                        model_output (`torch.Tensor`):
         | 
| 667 | 
            +
                            The direct output from learned diffusion model.
         | 
| 668 | 
            +
                        timestep (`int`):
         | 
| 669 | 
            +
                            The current discrete timestep in the diffusion chain.
         | 
| 670 | 
            +
                        sample (`torch.Tensor`):
         | 
| 671 | 
            +
                            A current instance of a sample created by the diffusion process.
         | 
| 672 | 
            +
                        return_dict (`bool`):
         | 
| 673 | 
            +
                            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
         | 
| 674 | 
            +
             | 
| 675 | 
            +
                    Returns:
         | 
| 676 | 
            +
                        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
         | 
| 677 | 
            +
                            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
         | 
| 678 | 
            +
                            tuple is returned where the first element is the sample tensor.
         | 
| 679 | 
            +
             | 
| 680 | 
            +
                    """
         | 
| 681 | 
            +
                    if self.num_inference_steps is None:
         | 
| 682 | 
            +
                        raise ValueError(
         | 
| 683 | 
            +
                            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
         | 
| 684 | 
            +
                        )
         | 
| 685 | 
            +
             | 
| 686 | 
            +
                    if self.step_index is None:
         | 
| 687 | 
            +
                        self._init_step_index(timestep)
         | 
| 688 | 
            +
             | 
| 689 | 
            +
                    use_corrector = (
         | 
| 690 | 
            +
                        self.step_index > 0 and
         | 
| 691 | 
            +
                        self.step_index - 1 not in self.disable_corrector and
         | 
| 692 | 
            +
                        self.last_sample is not None  # pyright: ignore
         | 
| 693 | 
            +
                    )
         | 
| 694 | 
            +
             | 
| 695 | 
            +
                    model_output_convert = self.convert_model_output(
         | 
| 696 | 
            +
                        model_output, sample=sample)
         | 
| 697 | 
            +
                    if use_corrector:
         | 
| 698 | 
            +
                        sample = self.multistep_uni_c_bh_update(
         | 
| 699 | 
            +
                            this_model_output=model_output_convert,
         | 
| 700 | 
            +
                            last_sample=self.last_sample,
         | 
| 701 | 
            +
                            this_sample=sample,
         | 
| 702 | 
            +
                            order=self.this_order,
         | 
| 703 | 
            +
                        )
         | 
| 704 | 
            +
             | 
| 705 | 
            +
                    for i in range(self.config.solver_order - 1):
         | 
| 706 | 
            +
                        self.model_outputs[i] = self.model_outputs[i + 1]
         | 
| 707 | 
            +
                        self.timestep_list[i] = self.timestep_list[i + 1]
         | 
| 708 | 
            +
             | 
| 709 | 
            +
                    self.model_outputs[-1] = model_output_convert
         | 
| 710 | 
            +
                    self.timestep_list[-1] = timestep  # pyright: ignore
         | 
| 711 | 
            +
             | 
| 712 | 
            +
                    if self.config.lower_order_final:
         | 
| 713 | 
            +
                        this_order = min(self.config.solver_order,
         | 
| 714 | 
            +
                                         len(self.timesteps) -
         | 
| 715 | 
            +
                                         self.step_index)  # pyright: ignore
         | 
| 716 | 
            +
                    else:
         | 
| 717 | 
            +
                        this_order = self.config.solver_order
         | 
| 718 | 
            +
             | 
| 719 | 
            +
                    self.this_order = min(this_order,
         | 
| 720 | 
            +
                                          self.lower_order_nums + 1)  # warmup for multistep
         | 
| 721 | 
            +
                    assert self.this_order > 0
         | 
| 722 | 
            +
             | 
| 723 | 
            +
                    self.last_sample = sample
         | 
| 724 | 
            +
                    prev_sample = self.multistep_uni_p_bh_update(
         | 
| 725 | 
            +
                        model_output=model_output,  # pass the original non-converted model output, in case solver-p is used
         | 
| 726 | 
            +
                        sample=sample,
         | 
| 727 | 
            +
                        order=self.this_order,
         | 
| 728 | 
            +
                    )
         | 
| 729 | 
            +
             | 
| 730 | 
            +
                    if self.lower_order_nums < self.config.solver_order:
         | 
| 731 | 
            +
                        self.lower_order_nums += 1
         | 
| 732 | 
            +
             | 
| 733 | 
            +
                    # upon completion increase step index by one
         | 
| 734 | 
            +
                    self._step_index += 1  # pyright: ignore
         | 
| 735 | 
            +
             | 
| 736 | 
            +
                    if not return_dict:
         | 
| 737 | 
            +
                        return (prev_sample,)
         | 
| 738 | 
            +
             | 
| 739 | 
            +
                    return SchedulerOutput(prev_sample=prev_sample)
         | 
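
    # Minimal denoising-loop sketch (illustrative, not from the original
    # source; `model`, `latents`, and the step count are placeholders):
    #
    #     sched.set_timesteps(num_inference_steps=50, device=latents.device)
    #     for t in sched.timesteps:
    #         velocity = model(latents, t)  # flow prediction
    #         latents = sched.step(velocity, t, latents).prev_sample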

    def scale_model_input(self, sample: torch.Tensor, *args,
                          **kwargs) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        return sample

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(
            device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(
                timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(
                original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(
                original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # begin_index is None when the scheduler is used for training or the pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [
                self.index_for_timestep(t, schedule_timesteps)
                for t in timesteps
            ]
        elif self.step_index is not None:
            # add_noise is called after the first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add_noise is called before the first denoising step to create the initial latent (img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples
| 798 | 
            +
             | 
| 799 | 
            +
                def __len__(self):
         | 
| 800 | 
            +
                    return self.config.num_train_timesteps
         | 
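A quick orientation note: add_noise() implements the forward-diffusion mix x_t = alpha_t * x_0 + sigma_t * eps used for img2img and inpainting starts. A minimal driver sketch, assuming the scheduler class this file defines is FlowUniPCMultistepScheduler with a diffusers-style set_timesteps (neither name is shown in this diff):

    import torch
    from humo.models.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler  # assumed class name

    scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(num_inference_steps=50)

    x0 = torch.randn(2, 16, 9, 32, 32)     # clean latents; [B, C, F, H, W] shape illustrative
    eps = torch.randn_like(x0)             # Gaussian noise
    t = scheduler.timesteps[[0, 10]]       # one schedule timestep per batch item

    xt = scheduler.add_noise(x0, eps, t)   # alpha_t * x0 + sigma_t * eps
    assert xt.shape == x0.shape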
    	
humo/models/utils/utils.py
ADDED
@@ -0,0 +1,58 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import argparse
+import binascii
+import os
+import os.path as osp
+import json
+from omegaconf import OmegaConf
+
+import imageio
+import torch
+import torchvision
+from moviepy.editor import AudioFileClip, VideoClip
+
+__all__ = ['tensor_to_video', 'prepare_json_dataset']
+
+
+def tensor_to_video(tensor, output_video_path, input_audio_path, fps=25):
+    """
+    Converts an array shaped [f, h, w, c] into a video and adds an audio track from the specified audio file.
+
+    Args:
+        tensor (numpy.ndarray): The frames to be written, shaped [f, h, w, c].
+        output_video_path (str): The file path where the output video will be saved.
+        input_audio_path (str): The path to the audio file (WAV file) that contains the audio track to be added.
+        fps (int): The frame rate of the output video. Default is 25 fps.
+    """
+    def make_frame(t):
+        frame_index = min(int(t * fps), tensor.shape[0] - 1)
+        return tensor[frame_index]
+
+    video_duration = tensor.shape[0] / fps
+    audio_clip = AudioFileClip(input_audio_path)
+    audio_duration = audio_clip.duration
+    final_duration = min(video_duration, audio_duration)
+    audio_clip = audio_clip.subclip(0, final_duration)
+    new_video_clip = VideoClip(make_frame, duration=final_duration)
+    new_video_clip = new_video_clip.set_audio(audio_clip)
+    new_video_clip.write_videofile(output_video_path, fps=fps, audio_codec="aac")
+
+
+def prepare_json_dataset(json_path):
+    samples = []
+    with open(json_path, "rb") as f:
+        data = json.load(f)
+    for itemname, row in data.items():
+        text = row['prompt'].strip().replace("_", " ").strip('"')
+        audio_path = row['audio_path']
+        ref_img_path = list(row['img_paths'])
+
+        samples.append({
+            "text": text,
+            "ref_img": ref_img_path,
+            "audio": audio_path,
+            "itemname": itemname
+        })
+    samples = OmegaConf.create(samples)
+
+    return samples
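For orientation, prepare_json_dataset() expects a JSON object mapping item names to records with prompt, audio_path, and img_paths fields. A minimal sketch of a compatible file (field names taken from the parser above; paths hypothetical):

    import json

    records = {
        "clip_0001": {
            "prompt": "A_woman_sings_on_stage",
            "audio_path": "examples/art.wav",
            "img_paths": ["examples/amber.png"],
        }
    }
    with open("cases.json", "w") as f:
        json.dump(records, f)

    samples = prepare_json_dataset("cases.json")
    print(samples[0].text)  # underscores become spaces: "A woman sings on stage"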
    	
humo/models/wan_modules/__init__.py
ADDED
@@ -0,0 +1,16 @@
+from .attention import flash_attention
+from .model import WanModel
+from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
+from .tokenizers import HuggingfaceTokenizer
+from .vae import WanVAE
+
+__all__ = [
+    'WanVAE',
+    'WanModel',
+    'T5Model',
+    'T5Encoder',
+    'T5Decoder',
+    'T5EncoderModel',
+    'HuggingfaceTokenizer',
+    'flash_attention',
+]
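Since the package re-exports its public entry points, downstream code can import them in one line:

    from humo.models.wan_modules import WanModel, WanVAE, T5EncoderModel, flash_attention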
    	
humo/models/wan_modules/attention.py
ADDED
@@ -0,0 +1,256 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import warnings
+from typing import Optional, Sequence
+
+import torch
+
+try:
+    import flash_attn_interface
+    FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_3_AVAILABLE = False
+
+try:
+    import flash_attn
+    FLASH_ATTN_2_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_2_AVAILABLE = False
+
+
+__all__ = [
+    'flash_attention',
+    'attention',
+]
+
+
+# ---------------------------
+# Custom op + fake kernel
+# ---------------------------
+@torch.library.custom_op("wan::flash_attention", mutates_args=())
+def _wan_flash_attention_op(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    q_lens: Optional[torch.Tensor] = None,
+    k_lens: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    softmax_scale: Optional[float] = None,
+    q_scale: Optional[float] = None,
+    causal: bool = False,
+    # IMPORTANT: schema-friendly default (None), not a tuple
+    window_size: Optional[Sequence[int]] = None,
+    deterministic: bool = False,
+    dtype: torch.dtype = torch.bfloat16,
+    version: Optional[int] = None,
+) -> torch.Tensor:
+    half_dtypes = (torch.float16, torch.bfloat16)
+    assert dtype in half_dtypes
+    assert q.size(-1) <= 256
+
+    # normalize window_size to a 2-tuple for the FA2 API
+    if window_size is None:
+        ws = (-1, -1)
+    else:
+        ws = tuple(window_size)
+        if len(ws) != 2:
+            raise ValueError(f"window_size must have length 2; got {window_size!r}")
+
+    b, lq, nheads = q.shape[0], q.shape[1], q.shape[2]
+    lk = k.shape[1]
+    out_dtype = q.dtype
+
+    def half(x: torch.Tensor) -> torch.Tensor:
+        return x if x.dtype in half_dtypes else x.to(dtype)
+
+    # --- preprocess: pack per-sample sequences into one varlen token axis ---
+    if q_lens is None:
+        q_flat = half(q.flatten(0, 1))
+        q_lens = torch.tensor([lq] * b, dtype=torch.int32)
+    else:
+        q_flat = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
+
+    if k_lens is None:
+        k_flat = half(k.flatten(0, 1))
+        v_flat = half(v.flatten(0, 1))
+        k_lens = torch.tensor([lk] * b, dtype=torch.int32)
+    else:
+        k_flat = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
+        v_flat = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
+
+    q_flat = q_flat.to(v_flat.dtype); k_flat = k_flat.to(v_flat.dtype)
+    if q_scale is not None:
+        q_flat = q_flat * q_scale
+
+    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
+        warnings.warn('Flash attention 3 is not available, use flash attention 2 instead.')
+
+    if FLASH_ATTN_3_AVAILABLE:
+        ret = flash_attn_interface.flash_attn_varlen_func(
+            q=q_flat,
+            k=k_flat,
+            v=v_flat,
+            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q_flat.device, non_blocking=True),
+            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(k_flat.device, non_blocking=True),
+            seqused_q=None,
+            seqused_k=None,
+            max_seqlen_q=lq,
+            max_seqlen_k=lk,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            deterministic=deterministic,
+        )
+        out0 = ret[0] if isinstance(ret, (tuple, list)) else ret
+        total_q = b * lq
+        if out0.dim() != 3:
+            raise RuntimeError(f"Unexpected FA3 output rank {out0.dim()} shape={tuple(out0.shape)}")
+        if out0.shape[0] == total_q:
+            out_flat = out0
+        elif out0.shape[0] == nheads and out0.shape[1] == total_q:
+            out_flat = out0.transpose(0, 1).contiguous()
+        else:
+            raise RuntimeError(f"Unexpected FA3 output shape {tuple(out0.shape)}")
+        out = out_flat.unflatten(0, (b, lq))
+
+    elif FLASH_ATTN_2_AVAILABLE:
+        out = flash_attn.flash_attn_varlen_func(
+            q=q_flat,
+            k=k_flat,
+            v=v_flat,
+            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q_flat.device, non_blocking=True),
+            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q_flat.device, non_blocking=True),
+            max_seqlen_q=lq,
+            max_seqlen_k=lk,
+            dropout_p=dropout_p,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            window_size=ws,                 # <- pass 2-tuple
+            deterministic=deterministic,
+        ).unflatten(0, (b, lq))
+    else:
+        # fallback: plain SDPA on dense [B, N, L, D] tensors (padding masks ignored)
+        q_s = q.transpose(1, 2).to(dtype)
+        k_s = k.transpose(1, 2).to(dtype)
+        v_s = v.transpose(1, 2).to(dtype)
+        out = torch.nn.functional.scaled_dot_product_attention(
+            q_s, k_s, v_s, attn_mask=None, is_causal=causal, dropout_p=dropout_p
+        ).transpose(1, 2).contiguous()
+
+    return out.to(out_dtype)
+
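The varlen convention used by both FlashAttention branches packs every sequence into a single token axis and marks batch boundaries with cumulative lengths. The cu_seqlens arithmetic from the calls above, in isolation:

    import torch

    k_lens = torch.tensor([7, 3], dtype=torch.int32)  # per-sample sequence lengths
    cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
    print(cu_seqlens_k)  # tensor([ 0,  7, 10], dtype=torch.int32)
    # sample i occupies packed rows cu_seqlens_k[i] : cu_seqlens_k[i + 1]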
+@_wan_flash_attention_op.register_fake
+def _wan_flash_attention_op_fake(
+    q,
+    k,
+    v,
+    q_lens=None,
+    k_lens=None,
+    dropout_p: float = 0.0,
+    softmax_scale=None,
+    q_scale=None,
+    causal: bool = False,
+    window_size: Optional[Sequence[int]] = None,
+    deterministic: bool = False,
+    dtype: torch.dtype = torch.bfloat16,
+    version: Optional[int] = None,
+):
+    # Match output shape: (B, Lq, Nq, Dh_v) and keep the SAME fake device as `q`
+    B, Lq, Nq, _ = q.shape
+    Dh_v = v.shape[-1]
+    return q.new_empty((B, Lq, Nq, Dh_v), dtype=q.dtype)
+
+
+# ---------------------------
+# Public API (unchanged signature)
+# ---------------------------
+def flash_attention(
+    q,
+    k,
+    v,
+    q_lens=None,
+    k_lens=None,
+    dropout_p=0.,
+    softmax_scale=None,
+    q_scale=None,
+    causal=False,
+    window_size=(-1, -1),
+    deterministic=False,
+    dtype=torch.bfloat16,
+    version=None,
+):
+    """
+    q:              [B, Lq, Nq, C1].
+    k:              [B, Lk, Nk, C1].
+    v:              [B, Lk, Nk, C2]. Nq must be divisible by Nk.
+    q_lens:         [B].
+    k_lens:         [B].
+    dropout_p:      float. Dropout probability.
+    softmax_scale:  float. The scaling of QK^T before applying softmax.
+    causal:         bool. Whether to apply a causal attention mask.
+    window_size:    (left, right). If not (-1, -1), apply sliding-window local attention.
+    deterministic:  bool. If True, slightly slower and uses more memory.
+    dtype:          torch.dtype. Applied when the dtype of q/k/v is not float16/bfloat16.
+    """
+    # Simply delegate to the custom op so Dynamo/AOT treats it as a single node;
+    # the eager kernel inside _wan_flash_attention_op keeps the original behavior.
+    return _wan_flash_attention_op(
+        q, k, v,
+        q_lens=q_lens,
+        k_lens=k_lens,
+        dropout_p=dropout_p,
+        softmax_scale=softmax_scale,
+        q_scale=q_scale,
+        causal=causal,
+        window_size=window_size,
+        deterministic=deterministic,
+        dtype=dtype,
+        version=version,
+    )
+
+
+def attention(
+    q,
+    k,
+    v,
+    q_lens=None,
+    k_lens=None,
+    dropout_p=0.,
+    softmax_scale=None,
+    q_scale=None,
+    causal=False,
+    window_size=(-1, -1),
+    deterministic=False,
+    dtype=torch.bfloat16,
+    fa_version=None,
+):
+    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
+        return flash_attention(
+            q=q,
+            k=k,
+            v=v,
+            q_lens=q_lens,
+            k_lens=k_lens,
+            dropout_p=dropout_p,
+            softmax_scale=softmax_scale,
+            q_scale=q_scale,
+            causal=causal,
+            window_size=window_size,
+            deterministic=deterministic,
+            dtype=dtype,
+            version=fa_version,
+        )
+    else:
+        if q_lens is not None or k_lens is not None:
+            warnings.warn(
+                'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
+            )
+        q_ = q.transpose(1, 2).to(dtype)
+        k_ = k.transpose(1, 2).to(dtype)
+        v_ = v.transpose(1, 2).to(dtype)
+        out = torch.nn.functional.scaled_dot_product_attention(
+            q_, k_, v_, attn_mask=None, is_causal=causal, dropout_p=dropout_p
+        )
+        return out.transpose(1, 2).contiguous()
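Both entry points take [B, L, N, D] tensors with heads on the third axis (unlike SDPA's [B, N, L, D]). A minimal smoke test; on a machine without the flash-attn wheels it exercises the SDPA fallback:

    import torch
    from humo.models.wan_modules.attention import attention

    q = torch.randn(2, 128, 8, 64, dtype=torch.bfloat16)
    k = torch.randn(2, 77, 8, 64, dtype=torch.bfloat16)
    v = torch.randn(2, 77, 8, 64, dtype=torch.bfloat16)

    out = attention(q, k, v)  # cross-attention: 128 queries over 77 keys
    assert out.shape == (2, 128, 8, 64)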
    	
humo/models/wan_modules/clip.py
ADDED
@@ -0,0 +1,542 @@
+# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import logging
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as T
+
+from .attention import flash_attention
+from .tokenizers import HuggingfaceTokenizer
+from .xlm_roberta import XLMRoberta
+
+__all__ = [
+    'XLMRobertaCLIP',
+    'clip_xlm_roberta_vit_h_14',
+    'CLIPModel',
+]
+
+
+def pos_interpolate(pos, seq_len):
+    if pos.size(1) == seq_len:
+        return pos
+    else:
+        src_grid = int(math.sqrt(pos.size(1)))
+        tar_grid = int(math.sqrt(seq_len))
+        n = pos.size(1) - src_grid * src_grid
+        return torch.cat([
+            pos[:, :n],
+            F.interpolate(
+                pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2),
+                size=(tar_grid, tar_grid),
+                mode='bicubic',
+                align_corners=False).flatten(2).transpose(1, 2)
+        ], dim=1)
+
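pos_interpolate() bicubically resizes a square grid of positional embeddings (keeping any leading class-token entries) so a checkpoint trained at one resolution can be evaluated at another. A shape-level sketch:

    import torch

    pos = torch.randn(1, 1 + 16 * 16, 1280)        # 1 cls entry + a 16x16 patch grid
    out = pos_interpolate(pos, seq_len=1 + 32 * 32)
    assert out.shape == (1, 1 + 32 * 32, 1280)     # grid resized 16x16 -> 32x32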
+class QuickGELU(nn.Module):
+
+    def forward(self, x):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class LayerNorm(nn.LayerNorm):
+
+    def forward(self, x):
+        # compute the norm in fp32 for stability, then cast back
+        return super().forward(x.float()).type_as(x)
+
+
+class SelfAttention(nn.Module):
+
+    def __init__(self,
+                 dim,
+                 num_heads,
+                 causal=False,
+                 attn_dropout=0.0,
+                 proj_dropout=0.0):
+        assert dim % num_heads == 0
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.causal = causal
+        self.attn_dropout = attn_dropout
+        self.proj_dropout = proj_dropout
+
+        # layers
+        self.to_qkv = nn.Linear(dim, dim * 3)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(self, x):
+        """
+        x:   [B, L, C].
+        """
+        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
+
+        # compute query, key, value
+        q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
+
+        # compute attention
+        p = self.attn_dropout if self.training else 0.0
+        x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
+        x = x.reshape(b, s, c)
+
+        # output
+        x = self.proj(x)
+        x = F.dropout(x, self.proj_dropout, self.training)
+        return x
+
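SelfAttention keeps tensors in the [B, L, N, D] layout that flash_attention expects; the to_qkv split above, as a standalone shape check (ViT-H-like sizes assumed):

    import torch

    b, s, c, n = 2, 257, 1280, 16        # batch, tokens, width, heads
    d = c // n                           # head_dim = 80
    qkv = torch.randn(b, s, 3 * c)       # what to_qkv produces
    q, k, v = qkv.view(b, s, 3, n, d).unbind(2)
    assert q.shape == (2, 257, 16, 80)   # heads stay on dim 2, not dim 1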
+class SwiGLU(nn.Module):
+
+    def __init__(self, dim, mid_dim):
+        super().__init__()
+        self.dim = dim
+        self.mid_dim = mid_dim
+
+        # layers
+        self.fc1 = nn.Linear(dim, mid_dim)
+        self.fc2 = nn.Linear(dim, mid_dim)
+        self.fc3 = nn.Linear(mid_dim, dim)
+
+    def forward(self, x):
+        x = F.silu(self.fc1(x)) * self.fc2(x)
+        x = self.fc3(x)
+        return x
+
+
+class AttentionBlock(nn.Module):
+
+    def __init__(self,
+                 dim,
+                 mlp_ratio,
+                 num_heads,
+                 post_norm=False,
+                 causal=False,
+                 activation='quick_gelu',
+                 attn_dropout=0.0,
+                 proj_dropout=0.0,
+                 norm_eps=1e-5):
+        assert activation in ['quick_gelu', 'gelu', 'swi_glu']
+        super().__init__()
+        self.dim = dim
+        self.mlp_ratio = mlp_ratio
+        self.num_heads = num_heads
+        self.post_norm = post_norm
+        self.causal = causal
+        self.norm_eps = norm_eps
+
+        # layers
+        self.norm1 = LayerNorm(dim, eps=norm_eps)
+        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
+                                  proj_dropout)
+        self.norm2 = LayerNorm(dim, eps=norm_eps)
+        if activation == 'swi_glu':
+            self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
+        else:
+            self.mlp = nn.Sequential(
+                nn.Linear(dim, int(dim * mlp_ratio)),
+                QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
+                nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
+
+    def forward(self, x):
+        if self.post_norm:
+            x = x + self.norm1(self.attn(x))
+            x = x + self.norm2(self.mlp(x))
+        else:
+            x = x + self.attn(self.norm1(x))
+            x = x + self.mlp(self.norm2(x))
+        return x
+
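The forward above selects between the default pre-norm residual wiring and an optional post-norm variant; the difference in isolation, with linear stand-ins for attn/mlp:

    import torch
    import torch.nn as nn

    attn, mlp = nn.Linear(8, 8), nn.Linear(8, 8)
    norm1, norm2 = nn.LayerNorm(8), nn.LayerNorm(8)
    x = torch.randn(2, 5, 8)

    # pre-norm (post_norm=False, the default): normalize, transform, then add
    h = x + attn(norm1(x))
    pre = h + mlp(norm2(h))

    # post-norm (post_norm=True): transform, normalize, then add
    h = x + norm1(attn(x))
    post = h + norm2(mlp(h))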
+class AttentionPool(nn.Module):
+
+    def __init__(self,
+                 dim,
+                 mlp_ratio,
+                 num_heads,
+                 activation='gelu',
+                 proj_dropout=0.0,
+                 norm_eps=1e-5):
+        assert dim % num_heads == 0
+        super().__init__()
+        self.dim = dim
+        self.mlp_ratio = mlp_ratio
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.proj_dropout = proj_dropout
+        self.norm_eps = norm_eps
+
+        # layers
+        gain = 1.0 / math.sqrt(dim)
+        self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
+        self.to_q = nn.Linear(dim, dim)
+        self.to_kv = nn.Linear(dim, dim * 2)
+        self.proj = nn.Linear(dim, dim)
+        self.norm = LayerNorm(dim, eps=norm_eps)
+        self.mlp = nn.Sequential(
+            nn.Linear(dim, int(dim * mlp_ratio)),
+            QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
+            nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
+
+    def forward(self, x):
+        """
+        x:  [B, L, C].
+        """
+        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
+
+        # compute query, key, value: one learned cls query attends over all tokens
+        q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
+        k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
+
+        # compute attention
+        x = flash_attention(q, k, v, version=2)
+        x = x.reshape(b, 1, c)
+
+        # output
+        x = self.proj(x)
+        x = F.dropout(x, self.proj_dropout, self.training)
+
+        # mlp
+        x = x + self.mlp(self.norm(x))
+        return x[:, 0]
+
+
+class VisionTransformer(nn.Module):
+
+    def __init__(self,
+                 image_size=224,
+                 patch_size=16,
+                 dim=768,
+                 mlp_ratio=4,
+                 out_dim=512,
+                 num_heads=12,
+                 num_layers=12,
+                 pool_type='token',
+                 pre_norm=True,
+                 post_norm=False,
+                 activation='quick_gelu',
+                 attn_dropout=0.0,
+                 proj_dropout=0.0,
+                 embedding_dropout=0.0,
+                 norm_eps=1e-5):
+        if image_size % patch_size != 0:
+            print(
+                '[WARNING] image_size is not divisible by patch_size',
+                flush=True)
+        assert pool_type in ('token', 'token_fc', 'attn_pool')
+        out_dim = out_dim or dim
+        super().__init__()
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_patches = (image_size // patch_size)**2
+        self.dim = dim
+        self.mlp_ratio = mlp_ratio
+        self.out_dim = out_dim
+        self.num_heads = num_heads
+        self.num_layers = num_layers
+        self.pool_type = pool_type
+        self.post_norm = post_norm
+        self.norm_eps = norm_eps
+
+        # embeddings
+        gain = 1.0 / math.sqrt(dim)
+        self.patch_embedding = nn.Conv2d(
+            3,
+            dim,
+            kernel_size=patch_size,
+            stride=patch_size,
+            bias=not pre_norm)
+        if pool_type in ('token', 'token_fc'):
+            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
+        self.pos_embedding = nn.Parameter(gain * torch.randn(
+            1, self.num_patches +
+            (1 if pool_type in ('token', 'token_fc') else 0), dim))
+        self.dropout = nn.Dropout(embedding_dropout)
+
+        # transformer
+        self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
+        self.transformer = nn.Sequential(*[
+            AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
+                           activation, attn_dropout, proj_dropout, norm_eps)
+            for _ in range(num_layers)
+        ])
+        # note: this reassignment shadows the boolean `post_norm` stored above
+        self.post_norm = LayerNorm(dim, eps=norm_eps)
+
+        # head
+        if pool_type == 'token':
+            self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
+        elif pool_type == 'token_fc':
+            self.head = nn.Linear(dim, out_dim)
+        elif pool_type == 'attn_pool':
+            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
+                                      proj_dropout, norm_eps)
+
+    def forward(self, x, interpolation=False, use_31_block=False):
+        b = x.size(0)
+
+        # embeddings
+        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
+        if self.pool_type in ('token', 'token_fc'):
+            x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
+        if interpolation:
+            e = pos_interpolate(self.pos_embedding, x.size(1))
+        else:
+            e = self.pos_embedding
+        x = self.dropout(x + e)
+        if self.pre_norm is not None:
+            x = self.pre_norm(x)
+
+        # transformer
+        if use_31_block:
+            x = self.transformer[:-1](x)
+            return x
+        else:
+            x = self.transformer(x)
+            return x
+

class XLMRobertaWithHead(XLMRoberta):

    def __init__(self, **kwargs):
        self.out_dim = kwargs.pop('out_dim')
        super().__init__(**kwargs)

        # head
        mid_dim = (self.dim + self.out_dim) // 2
        self.head = nn.Sequential(
            nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
            nn.Linear(mid_dim, self.out_dim, bias=False))

    def forward(self, ids):
        # xlm-roberta
        x = super().forward(ids)

        # average pooling
        mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
        x = (x * mask).sum(dim=1) / mask.sum(dim=1)

        # head
        x = self.head(x)
        return x
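
# A minimal sketch (not from the original file) of the masked mean pooling
# above: hidden states are averaged over non-padding positions only, assuming
# pad_id == 1 as configured for XLM-RoBERTa later in this module.
#
#   ids = torch.tensor([[5, 9, 2, 1, 1]])             # one sequence, two pads
#   x = torch.randn(1, 5, 8)                          # dummy hidden states
#   mask = ids.ne(1).unsqueeze(-1).to(x)              # [1, 5, 1], 0 at padding
#   pooled = (x * mask).sum(dim=1) / mask.sum(dim=1)  # mean over 3 real tokens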

class XLMRobertaCLIP(nn.Module):

    def __init__(self,
                 embed_dim=1024,
                 image_size=224,
                 patch_size=14,
                 vision_dim=1280,
                 vision_mlp_ratio=4,
                 vision_heads=16,
                 vision_layers=32,
                 vision_pool='token',
                 vision_pre_norm=True,
                 vision_post_norm=False,
                 activation='gelu',
                 vocab_size=250002,
                 max_text_len=514,
                 type_size=1,
                 pad_id=1,
                 text_dim=1024,
                 text_heads=16,
                 text_layers=24,
                 text_post_norm=True,
                 text_dropout=0.1,
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_dim = vision_dim
        self.vision_mlp_ratio = vision_mlp_ratio
        self.vision_heads = vision_heads
        self.vision_layers = vision_layers
        self.vision_pre_norm = vision_pre_norm
        self.vision_post_norm = vision_post_norm
        self.activation = activation
        self.vocab_size = vocab_size
        self.max_text_len = max_text_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.text_dim = text_dim
        self.text_heads = text_heads
        self.text_layers = text_layers
        self.text_post_norm = text_post_norm
        self.norm_eps = norm_eps

        # models
        self.visual = VisionTransformer(
            image_size=image_size,
            patch_size=patch_size,
            dim=vision_dim,
            mlp_ratio=vision_mlp_ratio,
            out_dim=embed_dim,
            num_heads=vision_heads,
            num_layers=vision_layers,
            pool_type=vision_pool,
            pre_norm=vision_pre_norm,
            post_norm=vision_post_norm,
            activation=activation,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
            norm_eps=norm_eps)
        self.textual = XLMRobertaWithHead(
            vocab_size=vocab_size,
            max_seq_len=max_text_len,
            type_size=type_size,
            pad_id=pad_id,
            dim=text_dim,
            out_dim=embed_dim,
            num_heads=text_heads,
            num_layers=text_layers,
            post_norm=text_post_norm,
            dropout=text_dropout)
        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))

    def forward(self, imgs, txt_ids):
        """
        imgs:       [B, 3, H, W] of torch.float32.
        - mean:     [0.48145466, 0.4578275, 0.40821073]
        - std:      [0.26862954, 0.26130258, 0.27577711]
        txt_ids:    [B, L] of torch.long.
                    Encoded by data.CLIPTokenizer.
        """
        xi = self.visual(imgs)
        xt = self.textual(txt_ids)
        return xi, xt

    def param_groups(self):
        # exclude norm weights and biases from weight decay
        groups = [{
            'params': [
                p for n, p in self.named_parameters()
                if 'norm' in n or n.endswith('bias')
            ],
            'weight_decay': 0.0
        }, {
            'params': [
                p for n, p in self.named_parameters()
                if not ('norm' in n or n.endswith('bias'))
            ]
        }]
        return groups
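
# A minimal usage sketch for param_groups() (assumption: a plain AdamW
# training setup; the hyperparameters are illustrative only). The first group
# exempts norm weights and biases from weight decay:
#
#   model = XLMRobertaCLIP()
#   optimizer = torch.optim.AdamW(
#       model.param_groups(), lr=1e-4, weight_decay=0.05)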

def _clip(pretrained=False,
          pretrained_name=None,
          model_cls=XLMRobertaCLIP,
          return_transforms=False,
          return_tokenizer=False,
          tokenizer_padding='eos',
          dtype=torch.float32,
          device='cpu',
          **kwargs):
    # init a model on device
    with torch.device(device):
        model = model_cls(**kwargs)

    # set device
    model = model.to(dtype=dtype, device=device)
    output = (model,)

    # init transforms
    if return_transforms:
        # mean and std
        if 'siglip' in pretrained_name.lower():
            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        else:
            mean = [0.48145466, 0.4578275, 0.40821073]
            std = [0.26862954, 0.26130258, 0.27577711]

        # transforms
        transforms = T.Compose([
            T.Resize((model.image_size, model.image_size),
                     interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std)
        ])
        output += (transforms,)
    return output[0] if len(output) == 1 else output

def clip_xlm_roberta_vit_h_14(
        pretrained=False,
        pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
        **kwargs):
    cfg = dict(
        embed_dim=1024,
        image_size=224,
        patch_size=14,
        vision_dim=1280,
        vision_mlp_ratio=4,
        vision_heads=16,
        vision_layers=32,
        vision_pool='token',
        activation='gelu',
        vocab_size=250002,
        max_text_len=514,
        type_size=1,
        pad_id=1,
        text_dim=1024,
        text_heads=16,
        text_layers=24,
        text_post_norm=True,
        text_dropout=0.1,
        attn_dropout=0.0,
        proj_dropout=0.0,
        embedding_dropout=0.0)
    cfg.update(**kwargs)
    return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
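
# A minimal construction sketch: with return_transforms=True the factory
# returns a (model, transforms) tuple; weights are randomly initialized here
# because pretrained=False.
#
#   model, transforms = clip_xlm_roberta_vit_h_14(
#       pretrained=False, return_transforms=True,
#       dtype=torch.float32, device='cpu')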

class CLIPModel:

    def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
        self.dtype = dtype
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # init model
        self.model, self.transforms = clip_xlm_roberta_vit_h_14(
            pretrained=False,
            return_transforms=True,
            return_tokenizer=False,
            dtype=dtype,
            device=device)
        self.model = self.model.eval().requires_grad_(False)
        logging.info(f'loading {checkpoint_path}')
        self.model.load_state_dict(
            torch.load(checkpoint_path, map_location='cpu'))

        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path,
            seq_len=self.model.max_text_len - 2,
            clean='whitespace')

    def visual(self, videos):
        # preprocess: resize each video to the CLIP input resolution
        size = (self.model.image_size,) * 2
        videos = torch.cat([
            F.interpolate(
                u.transpose(0, 1),
                size=size,
                mode='bicubic',
                align_corners=False) for u in videos
        ])
        videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))

        # forward
        with torch.amp.autocast('cuda', dtype=self.dtype):
            out = self.model.visual(videos, use_31_block=True)
            return out
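
# A minimal call sketch for CLIPModel.visual() (assumptions: `clip` is a
# constructed CLIPModel and CUDA is available, since the forward pass runs
# under torch.amp.autocast('cuda', ...)). Inputs are expected in [-1, 1];
# visual() rescales them to [0, 1] before CLIP normalization and returns
# penultimate-block (use_31_block=True) token features per frame.
#
#   videos = [torch.randn(3, 16, 480, 832, device='cuda')]  # [C, T, H, W]
#   tokens = clip.visual(videos)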
    	
humo/models/wan_modules/model.py
ADDED

@@ -0,0 +1,619 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math

import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

from .attention import flash_attention

__all__ = ['WanModel']

def sinusoidal_embedding_1d(dim, position):
    # preprocess
    assert dim % 2 == 0
    half = dim // 2
    position = position.type(torch.float64)

    # calculation
    sinusoid = torch.outer(
        position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    return x
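
# A minimal sketch (not from the original file): the helper maps integer
# timesteps to fixed cos/sin features, returned in float64 since `position`
# is upcast above.
#
#   t = torch.tensor([0, 250, 999])
#   emb = sinusoidal_embedding_1d(256, t)   # -> [3, 256]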

@torch.amp.autocast("cuda", enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta,
                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))
    freqs = torch.polar(torch.ones_like(freqs), freqs)
    return freqs
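
# A small sketch of what rope_params() produces: unit-modulus complex rotation
# factors, one per (position, channel-pair).
#
#   freqs = rope_params(1024, 64)            # complex tensor, shape [1024, 32]
#   assert torch.allclose(freqs.abs(), torch.ones_like(freqs.abs()))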

@torch.amp.autocast("cuda", enabled=False)
def rope_apply(x, grid_sizes, freqs):
    n, c = x.size(2), x.size(3) // 2

    # split freqs for the temporal, height and width axes
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers
        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
            seq_len, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ],
                            dim=-1).reshape(seq_len, 1, -1)

        # apply rotary embedding
        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
        x_i = torch.cat([x_i, x[i, seq_len:]])

        # append to collection
        output.append(x_i)
    return torch.stack(output).float()
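
# The split sizes above divide each head's c = head_dim / 2 complex channels
# across the temporal, height and width axes; e.g. for head_dim = 128:
#
#   c = 128 // 2
#   splits = [c - 2 * (c // 3), c // 3, c // 3]   # [22, 21, 21], sums to 64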

class WanRMSNorm(nn.Module):

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        return self._norm(x.float()).type_as(x) * self.weight

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
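
# A numerical sanity-check sketch: with the default unit weight, WanRMSNorm
# reduces to x / sqrt(mean(x**2) + eps).
#
#   norm = WanRMSNorm(8)
#   x = torch.randn(2, 4, 8)
#   ref = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
#   assert torch.allclose(norm(x), ref, atol=1e-6)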

class WanLayerNorm(nn.LayerNorm):

    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        return super().forward(x.float()).type_as(x)

class WanSelfAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, x, seq_lens, grid_sizes, freqs):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
            seq_lens(Tensor): Shape [B]
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        # query, key, value function
        def qkv_fn(x):
            q = self.norm_q(self.q(x)).view(b, s, n, d)
            k = self.norm_k(self.k(x)).view(b, s, n, d)
            v = self.v(x).view(b, s, n, d)
            return q, k, v

        q, k, v = qkv_fn(x)

        # RoPE is applied to queries and keys only; values are left unrotated
        x = flash_attention(
            q=rope_apply(q, grid_sizes, freqs),
            k=rope_apply(k, grid_sizes, freqs),
            v=v,
            k_lens=seq_lens,
            window_size=self.window_size)

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x
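
# A shape-only sketch of the multi-head layout, with PyTorch's built-in
# scaled_dot_product_attention standing in for flash_attention (which
# additionally handles k_lens and window_size):
#
#   import torch.nn.functional as F
#   b, s, n, d = 1, 16, 4, 32
#   q = k = v = torch.randn(b, s, n, d)
#   out = F.scaled_dot_product_attention(
#       q.transpose(1, 2), k.transpose(1, 2),
#       v.transpose(1, 2)).transpose(1, 2)    # [B, L, n, d], then flatten(2)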

class WanT2VCrossAttention(WanSelfAttention):

    def forward(self, x, context, context_lens):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]
        """
        b, n, d = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.norm_q(self.q(x)).view(b, -1, n, d)
        k = self.norm_k(self.k(context)).view(b, -1, n, d)
        v = self.v(context).view(b, -1, n, d)

        # compute attention
        x = flash_attention(q, k, v, k_lens=context_lens)

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x

class WanI2VCrossAttention(WanSelfAttention):

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        super().__init__(dim, num_heads, window_size, qk_norm, eps)

        self.k_img = nn.Linear(dim, dim)
        self.v_img = nn.Linear(dim, dim)
        # self.alpha = nn.Parameter(torch.zeros((1, )))
        self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, x, context, context_lens):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]
        """
        # the first 257 context tokens are the CLIP image tokens
        # (1 class token + (224 / 14)^2 = 256 patch tokens from the ViT-H/14
        # encoder above); the remainder is the text context
        context_img = context[:, :257]
        context = context[:, 257:]
        b, n, d = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.norm_q(self.q(x)).view(b, -1, n, d)
        k = self.norm_k(self.k(context)).view(b, -1, n, d)
        v = self.v(context).view(b, -1, n, d)
        k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
        v_img = self.v_img(context_img).view(b, -1, n, d)
        img_x = flash_attention(q, k_img, v_img, k_lens=None)
        # compute attention
        x = flash_attention(q, k, v, k_lens=context_lens)

        # output: sum of text-conditioned and image-conditioned attention
        x = x.flatten(2)
        img_x = img_x.flatten(2)
        x = x + img_x
        x = self.o(x)
        return x

WAN_CROSSATTENTION_CLASSES = {
    't2v_cross_attn': WanT2VCrossAttention,
    'i2v_cross_attn': WanI2VCrossAttention,
}

class WanAttentionBlock(nn.Module):

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # layers
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
                                          eps)
        self.norm3 = WanLayerNorm(
            dim, eps,
            elementwise_affine=True) if cross_attn_norm else nn.Identity()
        self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
                                                                      num_heads,
                                                                      (-1, -1),
                                                                      qk_norm,
                                                                      eps)
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim))

        # modulation
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        x,
        e,
        seq_lens,
        grid_sizes,
        freqs,
        context,
        context_lens,
    ):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
            e(Tensor): Shape [B, 6, C]
            seq_lens(Tensor): Shape [B], length of each sequence in batch
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        assert e.dtype == torch.float32
        with torch.amp.autocast('cuda', dtype=torch.float32):
            e = (self.modulation + e).chunk(6, dim=1)
        assert e[0].dtype == torch.float32

        # self-attention
        y = self.self_attn(
            self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
            freqs)
        with torch.amp.autocast('cuda', dtype=torch.float32):
            x = x + y * e[2]

        # cross-attention & ffn function
        def cross_attn_ffn(x, context, context_lens, e):
            x = x + self.cross_attn(self.norm3(x), context, context_lens)
            y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
            with torch.amp.autocast('cuda', dtype=torch.float32):
                x = x + y * e[5]
            return x

        x = cross_attn_ffn(x, context, context_lens, e)
        return x
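
# A shape-only sketch of the adaLN-style modulation above: the six chunks of e
# act as (shift, scale, gate) for self-attention followed by (shift, scale,
# gate) for the FFN.
#
#   dim = 64
#   e = torch.randn(1, 6, dim).chunk(6, dim=1)   # six [1, 1, dim] tensors
#   x = torch.randn(2, 10, dim)
#   h = x * (1 + e[1]) + e[0]    # modulated input to self-attention
#   x = x + h * e[2]             # gated residual (h stands in for attn output)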

class Head(nn.Module):

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # layers
        out_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, out_dim)

        # modulation
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, e):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, C]
        """
        assert e.dtype == torch.float32
        with torch.amp.autocast('cuda', dtype=torch.float32):
            e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
            x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
        return x

class MLPProj(torch.nn.Module):

    def __init__(self, in_dim, out_dim):
        super().__init__()

        self.proj = torch.nn.Sequential(
            torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
            torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
            torch.nn.LayerNorm(out_dim))

    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens
| 361 | 
            +
            class WanModel(ModelMixin, ConfigMixin):
         | 
| 362 | 
            +
                r"""
         | 
| 363 | 
            +
                Wan diffusion backbone supporting both text-to-video and image-to-video.
         | 
| 364 | 
            +
                """
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                ignore_for_config = [
         | 
| 367 | 
            +
                    'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
         | 
| 368 | 
            +
                ]
         | 
| 369 | 
            +
                _no_split_modules = ['WanAttentionBlock']
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                @register_to_config
         | 
| 372 | 
            +
                def __init__(self,
         | 
| 373 | 
            +
                             model_type='t2v',
         | 
| 374 | 
            +
                             patch_size=(1, 2, 2),
         | 
| 375 | 
            +
                             text_len=512,
         | 
| 376 | 
            +
                             in_dim=16,
         | 
| 377 | 
            +
                             dim=5120,
         | 
| 378 | 
            +
                             ffn_dim=13824,
         | 
| 379 | 
            +
                             freq_dim=256,
         | 
| 380 | 
            +
                             text_dim=4096,
         | 
| 381 | 
            +
                             out_dim=16,
         | 
| 382 | 
            +
                             num_heads=40,
         | 
| 383 | 
            +
                             num_layers=40,
         | 
| 384 | 
            +
                             window_size=(-1, -1),
         | 
| 385 | 
            +
                             qk_norm=True,
         | 
| 386 | 
            +
                             cross_attn_norm=True,
         | 
| 387 | 
            +
                             eps=1e-6):
         | 
| 388 | 
            +
                    r"""
         | 
| 389 | 
            +
                    Initialize the diffusion model backbone.
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                    Args:
         | 
| 392 | 
            +
                        model_type (`str`, *optional*, defaults to 't2v'):
         | 
| 393 | 
            +
                            Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
         | 
| 394 | 
            +
                        patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
         | 
| 395 | 
            +
                            3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
         | 
| 396 | 
            +
                        text_len (`int`, *optional*, defaults to 512):
         | 
| 397 | 
            +
                            Fixed length for text embeddings
         | 
| 398 | 
            +
                        in_dim (`int`, *optional*, defaults to 16):
         | 
| 399 | 
            +
                            Input video channels (C_in)
         | 
| 400 | 
            +
                        dim (`int`, *optional*, defaults to 2048):
         | 
| 401 | 
            +
                            Hidden dimension of the transformer
         | 
| 402 | 
            +
                        ffn_dim (`int`, *optional*, defaults to 8192):
         | 
| 403 | 
            +
                            Intermediate dimension in feed-forward network
         | 
| 404 | 
            +
                        freq_dim (`int`, *optional*, defaults to 256):
         | 
| 405 | 
            +
                            Dimension for sinusoidal time embeddings
         | 
| 406 | 
            +
                        text_dim (`int`, *optional*, defaults to 4096):
         | 
| 407 | 
            +
                            Input dimension for text embeddings
         | 
| 408 | 
            +
                        out_dim (`int`, *optional*, defaults to 16):
         | 
| 409 | 
            +
                            Output video channels (C_out)
         | 
| 410 | 
            +
                        num_heads (`int`, *optional*, defaults to 16):
         | 
| 411 | 
            +
                            Number of attention heads
         | 
| 412 | 
            +
                        num_layers (`int`, *optional*, defaults to 32):
         | 
| 413 | 
            +
                            Number of transformer blocks
         | 
| 414 | 
            +
                        window_size (`tuple`, *optional*, defaults to (-1, -1)):
         | 
| 415 | 
            +
                            Window size for local attention (-1 indicates global attention)
         | 
| 416 | 
            +
                        qk_norm (`bool`, *optional*, defaults to True):
         | 
| 417 | 
            +
                            Enable query/key normalization
         | 
| 418 | 
            +
                        cross_attn_norm (`bool`, *optional*, defaults to False):
         | 
| 419 | 
            +
                            Enable cross-attention normalization
         | 
| 420 | 
            +
                        eps (`float`, *optional*, defaults to 1e-6):
         | 
| 421 | 
            +
                            Epsilon value for normalization layers
         | 
| 422 | 
            +
                    """
         | 
| 423 | 
            +
             | 
| 424 | 
            +
                    super().__init__()
         | 
| 425 | 
            +
             | 
| 426 | 
            +
                    assert model_type in ['t2v', 'i2v']
         | 
| 427 | 
            +
                    self.model_type = model_type
         | 
| 428 | 
            +
             | 
| 429 | 
            +
                    self.patch_size = patch_size
         | 
| 430 | 
            +
                    self.text_len = text_len
         | 
| 431 | 
            +
                    self.in_dim = in_dim
         | 
| 432 | 
            +
                    self.dim = dim
         | 
| 433 | 
            +
                    self.ffn_dim = ffn_dim
         | 
| 434 | 
            +
                    self.freq_dim = freq_dim
         | 
| 435 | 
            +
                    self.text_dim = text_dim
         | 
| 436 | 
            +
                    self.out_dim = out_dim
         | 
| 437 | 
            +
                    self.num_heads = num_heads
         | 
| 438 | 
            +
                    self.num_layers = num_layers
         | 
| 439 | 
            +
                    self.window_size = window_size
         | 
| 440 | 
            +
                    self.qk_norm = qk_norm
         | 
| 441 | 
            +
                    self.cross_attn_norm = cross_attn_norm
         | 
| 442 | 
            +
                    self.eps = eps
         | 
| 443 | 
            +
             | 
| 444 | 
            +
                    # embeddings
         | 
| 445 | 
            +
                    self.patch_embedding = nn.Conv3d(
         | 
| 446 | 
            +
                        in_dim, dim, kernel_size=patch_size, stride=patch_size)
         | 
| 447 | 
            +
                    self.text_embedding = nn.Sequential(
         | 
| 448 | 
            +
                        nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
         | 
| 449 | 
            +
                        nn.Linear(dim, dim))
         | 
| 450 | 
            +
             | 
| 451 | 
            +
                    self.time_embedding = nn.Sequential(
         | 
| 452 | 
            +
                        nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
         | 
| 453 | 
            +
                    self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
         | 
| 454 | 
            +
             | 
| 455 | 
            +
                    # blocks
         | 
| 456 | 
            +
                    cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
         | 
| 457 | 
            +
                    self.blocks = nn.ModuleList([
         | 
| 458 | 
            +
                        WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
         | 
| 459 | 
            +
                                          window_size, qk_norm, cross_attn_norm, eps)
         | 
| 460 | 
            +
                        for _ in range(num_layers)
         | 
| 461 | 
            +
                    ])
         | 
| 462 | 
            +
             | 
| 463 | 
            +
                    # head
         | 
| 464 | 
            +
                    self.head = Head(dim, out_dim, patch_size, eps)
         | 
| 465 | 
            +
             | 
| 466 | 
            +
        # buffers (avoid register_buffer, otherwise .to() would change their dtype)
         | 
| 467 | 
            +
                    assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
         | 
| 468 | 
            +
                    d = dim // num_heads
         | 
| 469 | 
            +
                    self.freqs = torch.cat([
         | 
| 470 | 
            +
                        rope_params(1024, d - 4 * (d // 6)),
         | 
| 471 | 
            +
                        rope_params(1024, 2 * (d // 6)),
         | 
| 472 | 
            +
                        rope_params(1024, 2 * (d // 6))
         | 
| 473 | 
            +
                    ],
         | 
| 474 | 
            +
                                           dim=1)
         | 
| 475 | 
            +
             | 
| 476 | 
            +
                    if model_type == 'i2v':
         | 
| 477 | 
            +
                        self.img_emb = MLPProj(1280, dim)
         | 
| 478 | 
            +
             | 
| 479 | 
            +
                    # initialize weights
         | 
| 480 | 
            +
                    self.init_weights()
         | 
| 481 | 
            +
             | 
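For orientation, a minimal construction sketch using the defaults documented in the docstring above (illustrative sizes, not tied to any released checkpoint):

model = WanModel(model_type='t2v', patch_size=(1, 2, 2), text_len=512,
                 in_dim=16, dim=2048, ffn_dim=8192, freq_dim=256,
                 text_dim=4096, out_dim=16, num_heads=16, num_layers=32)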
| 482 | 
            +
                def forward(
         | 
| 483 | 
            +
                    self,
         | 
| 484 | 
            +
                    x,
         | 
| 485 | 
            +
                    t,
         | 
| 486 | 
            +
                    context,
         | 
| 487 | 
            +
                    seq_len,
         | 
| 488 | 
            +
                    clip_fea=None,
         | 
| 489 | 
            +
                    y=None,
         | 
| 490 | 
            +
                ):
         | 
| 491 | 
            +
                    r"""
         | 
| 492 | 
            +
                    Forward pass through the diffusion model
         | 
| 493 | 
            +
             | 
| 494 | 
            +
                    Args:
         | 
| 495 | 
            +
                        x (List[Tensor]):
         | 
| 496 | 
            +
                            List of input video tensors, each with shape [C_in, F, H, W]
         | 
| 497 | 
            +
                        t (Tensor):
         | 
| 498 | 
            +
                            Diffusion timesteps tensor of shape [B]
         | 
| 499 | 
            +
                        context (List[Tensor]):
         | 
| 500 | 
            +
                            List of text embeddings each with shape [L, C]
         | 
| 501 | 
            +
                        seq_len (`int`):
         | 
| 502 | 
            +
                            Maximum sequence length for positional encoding
         | 
| 503 | 
            +
                        clip_fea (Tensor, *optional*):
         | 
| 504 | 
            +
                            CLIP image features for image-to-video mode
         | 
| 505 | 
            +
                        y (List[Tensor], *optional*):
         | 
| 506 | 
            +
                            Conditional video inputs for image-to-video mode, same shape as x
         | 
| 507 | 
            +
             | 
| 508 | 
            +
                    Returns:
         | 
| 509 | 
            +
                        List[Tensor]:
         | 
| 510 | 
            +
                List of denoised latent video tensors, each with shape [C_out, F, H, W] matching its input (H and W are already 1/8 of the pixel resolution)
         | 
| 511 | 
            +
                    """
         | 
| 512 | 
            +
                    if self.model_type == 'i2v':
         | 
| 513 | 
            +
                        assert clip_fea is not None and y is not None
         | 
| 514 | 
            +
                    # params
         | 
| 515 | 
            +
                    device = self.patch_embedding.weight.device
         | 
| 516 | 
            +
                    freqs = self.freqs.to(device)
         | 
| 517 | 
            +
             | 
| 518 | 
            +
                    if y is not None:
         | 
| 519 | 
            +
                        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
         | 
| 520 | 
            +
             | 
| 521 | 
            +
                    # embeddings
         | 
| 522 | 
            +
                    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
         | 
| 523 | 
            +
                    grid_sizes = torch.stack(
         | 
| 524 | 
            +
                        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
         | 
| 525 | 
            +
                    x = [u.flatten(2).transpose(1, 2) for u in x]
         | 
| 526 | 
            +
                    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
         | 
| 527 | 
            +
                    assert seq_lens.max() <= seq_len
         | 
| 528 | 
            +
                    x = torch.cat([
         | 
| 529 | 
            +
                        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
         | 
| 530 | 
            +
                                  dim=1) for u in x
         | 
| 531 | 
            +
                    ])
         | 
| 532 | 
            +
             | 
| 533 | 
            +
                    # time embeddings
         | 
| 534 | 
            +
                    with torch.amp.autocast('cuda', dtype=torch.float32):
         | 
| 535 | 
            +
                        e = self.time_embedding(
         | 
| 536 | 
            +
                            sinusoidal_embedding_1d(self.freq_dim, t).float())
         | 
| 537 | 
            +
                        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
         | 
| 538 | 
            +
                        assert e.dtype == torch.float32 and e0.dtype == torch.float32
         | 
| 539 | 
            +
             | 
| 540 | 
            +
                    # context
         | 
| 541 | 
            +
                    context_lens = None
         | 
| 542 | 
            +
                    context = self.text_embedding(
         | 
| 543 | 
            +
                        torch.stack([
         | 
| 544 | 
            +
                            torch.cat(
         | 
| 545 | 
            +
                                [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
         | 
| 546 | 
            +
                            for u in context
         | 
| 547 | 
            +
                        ]))
         | 
| 548 | 
            +
             | 
| 549 | 
            +
                    if clip_fea is not None:
         | 
| 550 | 
            +
                        context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
         | 
| 551 | 
            +
                        context = torch.concat([context_clip, context], dim=1)
         | 
| 552 | 
            +
             | 
| 553 | 
            +
                    # arguments
         | 
| 554 | 
            +
                    kwargs = dict(
         | 
| 555 | 
            +
                        e=e0,
         | 
| 556 | 
            +
                        seq_lens=seq_lens,
         | 
| 557 | 
            +
                        grid_sizes=grid_sizes,
         | 
| 558 | 
            +
                        freqs=freqs,
         | 
| 559 | 
            +
                        context=context,
         | 
| 560 | 
            +
                        context_lens=context_lens)
         | 
| 561 | 
            +
             | 
| 562 | 
            +
                    for block in self.blocks:
         | 
| 563 | 
            +
                        x = block(x, **kwargs)
         | 
| 564 | 
            +
             | 
| 565 | 
            +
                    # head
         | 
| 566 | 
            +
                    x = self.head(x, e)
         | 
| 567 | 
            +
             | 
| 568 | 
            +
                    # unpatchify
         | 
| 569 | 
            +
                    x = self.unpatchify(x, grid_sizes)
         | 
| 570 | 
            +
                    return [u.float() for u in x]
         | 
| 571 | 
            +
             | 
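A sketch of a single denoising call on the model constructed above; shapes are illustrative, and the model and inputs would need to live on a CUDA device since the blocks rely on flash_attention:

x = [torch.randn(16, 21, 60, 104)]        # latents, [C_in, F, H, W]
t = torch.tensor([500])                   # timesteps, shape [B]
context = [torch.randn(77, 4096)]         # text embeddings, [L, text_dim]
seq_len = 21 * (60 // 2) * (104 // 2)     # tokens after (1, 2, 2) patching
out = model(x, t, context, seq_len)       # -> list of [16, 21, 60, 104]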
| 572 | 
            +
                def unpatchify(self, x, grid_sizes):
         | 
| 573 | 
            +
                    r"""
         | 
| 574 | 
            +
                    Reconstruct video tensors from patch embeddings.
         | 
| 575 | 
            +
             | 
| 576 | 
            +
                    Args:
         | 
| 577 | 
            +
                        x (List[Tensor]):
         | 
| 578 | 
            +
                            List of patchified features, each with shape [L, C_out * prod(patch_size)]
         | 
| 579 | 
            +
                        grid_sizes (Tensor):
         | 
| 580 | 
            +
                            Original spatial-temporal grid dimensions before patching,
         | 
| 581 | 
            +
                                shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
         | 
| 582 | 
            +
             | 
| 583 | 
            +
                    Returns:
         | 
| 584 | 
            +
                        List[Tensor]:
         | 
| 585 | 
            +
                Reconstructed latent video tensors with shape [C_out, F, H, W]
         | 
| 586 | 
            +
                    """
         | 
| 587 | 
            +
             | 
| 588 | 
            +
                    c = self.out_dim
         | 
| 589 | 
            +
                    out = []
         | 
| 590 | 
            +
                    for u, v in zip(x, grid_sizes.tolist()):
         | 
| 591 | 
            +
                        u = u[:math.prod(v)].view(*v, *self.patch_size, c)
         | 
| 592 | 
            +
                        u = torch.einsum('fhwpqrc->cfphqwr', u)
         | 
| 593 | 
            +
                        u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
         | 
| 594 | 
            +
                        out.append(u)
         | 
| 595 | 
            +
                    return out
         | 
| 596 | 
            +
             | 
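The einsum above simply inverts the patching; a self-contained shape walk-through of the same three operations (grid and patch sizes are illustrative):

import math, torch
v, patch, c = (21, 30, 52), (1, 2, 2), 16            # grid, patch_size, out_dim
u = torch.randn(math.prod(v), c * math.prod(patch))
u = u.view(*v, *patch, c)                            # (f, h, w, p, q, r, c)
u = torch.einsum('fhwpqrc->cfphqwr', u)              # interleave patches back
u = u.reshape(c, *[i * j for i, j in zip(v, patch)])
assert u.shape == (16, 21, 60, 104)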
| 597 | 
            +
                def init_weights(self):
         | 
| 598 | 
            +
                    r"""
         | 
| 599 | 
            +
                    Initialize model parameters using Xavier initialization.
         | 
| 600 | 
            +
                    """
         | 
| 601 | 
            +
             | 
| 602 | 
            +
                    # basic init
         | 
| 603 | 
            +
                    for m in self.modules():
         | 
| 604 | 
            +
                        if isinstance(m, nn.Linear):
         | 
| 605 | 
            +
                            nn.init.xavier_uniform_(m.weight)
         | 
| 606 | 
            +
                            if m.bias is not None:
         | 
| 607 | 
            +
                                nn.init.zeros_(m.bias)
         | 
| 608 | 
            +
             | 
| 609 | 
            +
                    # init embeddings
         | 
| 610 | 
            +
                    nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
         | 
| 611 | 
            +
                    for m in self.text_embedding.modules():
         | 
| 612 | 
            +
                        if isinstance(m, nn.Linear):
         | 
| 613 | 
            +
                            nn.init.normal_(m.weight, std=.02)
         | 
| 614 | 
            +
                    for m in self.time_embedding.modules():
         | 
| 615 | 
            +
                        if isinstance(m, nn.Linear):
         | 
| 616 | 
            +
                            nn.init.normal_(m.weight, std=.02)
         | 
| 617 | 
            +
             | 
| 618 | 
            +
                    # init output layer
         | 
| 619 | 
            +
                    nn.init.zeros_(self.head.head.weight)
         | 
    	
        humo/models/wan_modules/model_humo.py
    ADDED
    
    | @@ -0,0 +1,803 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            from torch import nn
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from common.distributed import get_device
         | 
| 5 | 
            +
            from models.audio.audio_proj import AudioProjModel
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import torch.cuda.amp as amp
         | 
| 8 | 
            +
            import math
         | 
| 9 | 
            +
            from humo.models.wan_modules.attention import flash_attention
         | 
| 10 | 
            +
            from common.distributed.advanced import is_unified_parallel_initialized
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import types
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            def sinusoidal_embedding_1d(dim, position):
         | 
| 15 | 
            +
                # preprocess
         | 
| 16 | 
            +
                assert dim % 2 == 0
         | 
| 17 | 
            +
                half = dim // 2
         | 
| 18 | 
            +
                position = position.type(torch.float64)
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                # calculation
         | 
| 21 | 
            +
                sinusoid = torch.outer(
         | 
| 22 | 
            +
                    position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
         | 
| 23 | 
            +
                x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
         | 
| 24 | 
            +
                return x
         | 
| 25 | 
            +
             | 
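A quick shape check of the embedding above, using the model's freq_dim of 256:

t = torch.tensor([0, 250, 999])
emb = sinusoidal_embedding_1d(256, t)     # -> torch.Size([3, 256]), float64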
| 26 | 
            +
             | 
| 27 | 
            +
            @amp.autocast(enabled=False)
         | 
| 28 | 
            +
            def rope_params(max_seq_len, dim, theta=10000):
         | 
| 29 | 
            +
                assert dim % 2 == 0
         | 
| 30 | 
            +
                freqs = torch.outer(
         | 
| 31 | 
            +
                    torch.arange(max_seq_len),
         | 
| 32 | 
            +
                    1.0 / torch.pow(theta,
         | 
| 33 | 
            +
                                    torch.arange(0, dim, 2).to(torch.float32).div(dim)))
         | 
| 34 | 
            +
                freqs = torch.polar(torch.ones_like(freqs), freqs)
         | 
| 35 | 
            +
                return freqs
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
            @amp.autocast(enabled=False)
         | 
| 39 | 
            +
            def rope_apply(x, grid_sizes, freqs):
         | 
| 40 | 
            +
                n, c = x.size(2), x.size(3) // 2
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                # split freqs
         | 
| 43 | 
            +
                freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
         | 
| 44 | 
            +
                
         | 
| 45 | 
            +
                # loop over samples
         | 
| 46 | 
            +
                output = []
         | 
| 47 | 
            +
                for i, (f, h, w) in enumerate(grid_sizes.tolist()):
         | 
| 48 | 
            +
                    seq_len = f * h * w
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                    # precompute multipliers
         | 
| 51 | 
            +
                    x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape(
         | 
| 52 | 
            +
                        seq_len, n, -1, 2))
         | 
| 53 | 
            +
                    freqs_i = torch.cat([
         | 
| 54 | 
            +
                        freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
         | 
| 55 | 
            +
                        freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
         | 
| 56 | 
            +
                        freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
         | 
| 57 | 
            +
                    ],
         | 
| 58 | 
            +
                                        dim=-1).reshape(seq_len, 1, -1)
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    # apply rotary embedding
         | 
| 61 | 
            +
                    x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
         | 
| 62 | 
            +
                    x_i = torch.cat([x_i, x[i, seq_len:]])
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    # append to collection
         | 
| 65 | 
            +
                    output.append(x_i)
         | 
| 66 | 
            +
                return torch.stack(output).float()
         | 
| 67 | 
            +
             | 
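These two helpers carry the 3D RoPE: the model concatenates three rope_params tables (temporal, height, width) whose complex widths sum to head_dim // 2, and rope_apply re-splits them with the matching [c - 2*(c//3), c//3, c//3] rule. A sketch for head_dim 128 (i.e. dim=2048 with 16 heads):

d = 128
freqs = torch.cat([
    rope_params(1024, d - 4 * (d // 6)),  # temporal: 22 complex pairs
    rope_params(1024, 2 * (d // 6)),      # height:   21 complex pairs
    rope_params(1024, 2 * (d // 6)),      # width:    21 complex pairs
], dim=1)
assert freqs.shape == (1024, d // 2)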
| 68 | 
            +
             | 
| 69 | 
            +
            class WanRMSNorm(nn.Module):
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                def __init__(self, dim, eps=1e-5):
         | 
| 72 | 
            +
                    super().__init__()
         | 
| 73 | 
            +
                    self.dim = dim
         | 
| 74 | 
            +
                    self.eps = eps
         | 
| 75 | 
            +
                    self.weight = nn.Parameter(torch.ones(dim))
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                def forward(self, x):
         | 
| 78 | 
            +
                    r"""
         | 
| 79 | 
            +
                    Args:
         | 
| 80 | 
            +
                        x(Tensor): Shape [B, L, C]
         | 
| 81 | 
            +
                    """
         | 
| 82 | 
            +
                    return self._norm(x.float()).type_as(x) * self.weight
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                def _norm(self, x):
         | 
| 85 | 
            +
                    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
         | 
| 86 | 
            +
             | 
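A tiny numeric check of the normalization above (the weight is initialized to ones, so the per-token RMS comes out at ~1):

norm = WanRMSNorm(dim=8)
x = torch.randn(2, 4, 8)
y = norm(x)
assert torch.allclose(y.pow(2).mean(-1).sqrt(), torch.ones(2, 4), atol=1e-2)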
| 87 | 
            +
             | 
| 88 | 
            +
            class WanLayerNorm(nn.LayerNorm):
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                def __init__(self, dim, eps=1e-6, elementwise_affine=False):
         | 
| 91 | 
            +
                    super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                def forward(self, x):
         | 
| 94 | 
            +
                    r"""
         | 
| 95 | 
            +
                    Args:
         | 
| 96 | 
            +
                        x(Tensor): Shape [B, L, C]
         | 
| 97 | 
            +
                    """
         | 
| 98 | 
            +
                    return super().forward(x.float()).type_as(x)
         | 
| 99 | 
            +
             | 
| 100 | 
            +
             | 
| 101 | 
            +
            class WanSelfAttention(nn.Module):
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                def __init__(self,
         | 
| 104 | 
            +
                             dim,
         | 
| 105 | 
            +
                             num_heads,
         | 
| 106 | 
            +
                             window_size=(-1, -1),
         | 
| 107 | 
            +
                             qk_norm=True,
         | 
| 108 | 
            +
                             eps=1e-6):
         | 
| 109 | 
            +
                    assert dim % num_heads == 0
         | 
| 110 | 
            +
                    super().__init__()
         | 
| 111 | 
            +
                    self.dim = dim
         | 
| 112 | 
            +
                    self.num_heads = num_heads
         | 
| 113 | 
            +
                    self.head_dim = dim // num_heads
         | 
| 114 | 
            +
                    self.window_size = window_size
         | 
| 115 | 
            +
                    self.qk_norm = qk_norm
         | 
| 116 | 
            +
                    self.eps = eps
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                    # layers
         | 
| 119 | 
            +
                    self.q = nn.Linear(dim, dim)
         | 
| 120 | 
            +
                    self.k = nn.Linear(dim, dim)
         | 
| 121 | 
            +
                    self.v = nn.Linear(dim, dim)
         | 
| 122 | 
            +
                    self.o = nn.Linear(dim, dim)
         | 
| 123 | 
            +
                    self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
         | 
| 124 | 
            +
                    self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                def forward(self, x, seq_lens, grid_sizes, freqs):
         | 
| 127 | 
            +
                    r"""
         | 
| 128 | 
            +
                    Args:
         | 
| 129 | 
            +
            x(Tensor): Shape [B, L, C], e.g. torch.Size([1, 9360, 5120])
         | 
| 130 | 
            +
                        seq_lens(Tensor): Shape [B], tensor([9360])
         | 
| 131 | 
            +
                        grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W), tensor([[ 6, 30, 52]])
         | 
| 132 | 
            +
                        freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         | 
| 133 | 
            +
                    """
         | 
| 134 | 
            +
                    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    # query, key, value function
         | 
| 137 | 
            +
                    def qkv_fn(x):
         | 
| 138 | 
            +
                        q = self.norm_q(self.q(x)).view(b, s, n, d)
         | 
| 139 | 
            +
                        k = self.norm_k(self.k(x)).view(b, s, n, d)
         | 
| 140 | 
            +
                        v = self.v(x).view(b, s, n, d)
         | 
| 141 | 
            +
                        return q, k, v
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                    q, k, v = qkv_fn(x)
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                    x = flash_attention(
         | 
| 146 | 
            +
                        q=rope_apply(q, grid_sizes, freqs),
         | 
| 147 | 
            +
                        k=rope_apply(k, grid_sizes, freqs),
         | 
| 148 | 
            +
                        v=v,
         | 
| 149 | 
            +
                        k_lens=seq_lens,
         | 
| 150 | 
            +
                        window_size=self.window_size)
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                    # output
         | 
| 153 | 
            +
                    x = x.flatten(2)
         | 
| 154 | 
            +
                    x = self.o(x)
         | 
| 155 | 
            +
                    return x
         | 
| 156 | 
            +
             | 
| 157 | 
            +
             | 
| 158 | 
            +
            class WanSelfAttentionSepKVDim(nn.Module):
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                def __init__(self,
         | 
| 161 | 
            +
                             kv_dim,
         | 
| 162 | 
            +
                             dim,
         | 
| 163 | 
            +
                             num_heads,
         | 
| 164 | 
            +
                             window_size=(-1, -1),
         | 
| 165 | 
            +
                             qk_norm=True,
         | 
| 166 | 
            +
                             eps=1e-6):
         | 
| 167 | 
            +
                    assert dim % num_heads == 0
         | 
| 168 | 
            +
                    super().__init__()
         | 
| 169 | 
            +
                    self.dim = dim
         | 
| 170 | 
            +
                    self.num_heads = num_heads
         | 
| 171 | 
            +
                    self.head_dim = dim // num_heads
         | 
| 172 | 
            +
                    self.window_size = window_size
         | 
| 173 | 
            +
                    self.qk_norm = qk_norm
         | 
| 174 | 
            +
                    self.eps = eps
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    # layers
         | 
| 177 | 
            +
                    self.q = nn.Linear(dim, dim)
         | 
| 178 | 
            +
                    self.k = nn.Linear(kv_dim, dim)
         | 
| 179 | 
            +
                    self.v = nn.Linear(kv_dim, dim)
         | 
| 180 | 
            +
                    self.o = nn.Linear(dim, dim)
         | 
| 181 | 
            +
                    self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
         | 
| 182 | 
            +
                    self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                def forward(self, x, seq_lens, grid_sizes, freqs):
         | 
| 185 | 
            +
                    r"""
         | 
| 186 | 
            +
                    Args:
         | 
| 187 | 
            +
            x(Tensor): Shape [B, L, C], e.g. torch.Size([1, 9360, 5120])
         | 
| 188 | 
            +
                        seq_lens(Tensor): Shape [B], tensor([9360])
         | 
| 189 | 
            +
                        grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W), tensor([[ 6, 30, 52]])
         | 
| 190 | 
            +
                        freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         | 
| 191 | 
            +
                    """
         | 
| 192 | 
            +
                    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                    # query, key, value function
         | 
| 195 | 
            +
                    def qkv_fn(x):
         | 
| 196 | 
            +
                        q = self.norm_q(self.q(x)).view(b, s, n, d)
         | 
| 197 | 
            +
                        k = self.norm_k(self.k(x)).view(b, s, n, d)
         | 
| 198 | 
            +
                        v = self.v(x).view(b, s, n, d)
         | 
| 199 | 
            +
                        return q, k, v
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                    q, k, v = qkv_fn(x)
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                    x = flash_attention(
         | 
| 204 | 
            +
                        q=rope_apply(q, grid_sizes, freqs),
         | 
| 205 | 
            +
                        k=rope_apply(k, grid_sizes, freqs),
         | 
| 206 | 
            +
                        v=v,
         | 
| 207 | 
            +
                        k_lens=seq_lens,
         | 
| 208 | 
            +
                        window_size=self.window_size)
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    # output
         | 
| 211 | 
            +
                    x = x.flatten(2)
         | 
| 212 | 
            +
                    x = self.o(x)
         | 
| 213 | 
            +
                    return x
         | 
| 214 | 
            +
             | 
| 215 | 
            +
             | 
| 216 | 
            +
             | 
| 217 | 
            +
            class WanT2VCrossAttention(WanSelfAttention):
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                def forward(self, x, context, context_lens):
         | 
| 220 | 
            +
                    r"""
         | 
| 221 | 
            +
                    Args:
         | 
| 222 | 
            +
                        x(Tensor): Shape [B, L1, C]
         | 
| 223 | 
            +
                        context(Tensor): Shape [B, L2, C]
         | 
| 224 | 
            +
                        context_lens(Tensor): Shape [B]
         | 
| 225 | 
            +
                    """
         | 
| 226 | 
            +
                    b, n, d = x.size(0), self.num_heads, self.head_dim
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                    # compute query, key, value
         | 
| 229 | 
            +
                    q = self.norm_q(self.q(x)).view(b, -1, n, d)
         | 
| 230 | 
            +
                    k = self.norm_k(self.k(context)).view(b, -1, n, d)
         | 
| 231 | 
            +
                    v = self.v(context).view(b, -1, n, d)
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                    # compute attention
         | 
| 234 | 
            +
                    x = flash_attention(q, k, v, k_lens=context_lens)
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                    # output
         | 
| 237 | 
            +
                    x = x.flatten(2)
         | 
| 238 | 
            +
                    x = self.o(x)
         | 
| 239 | 
            +
                    return x
         | 
| 240 | 
            +
             | 
| 241 | 
            +
             | 
| 242 | 
            +
            class WanT2VCrossAttentionGather(WanSelfAttentionSepKVDim):
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                def forward(self, x, context, context_lens, grid_sizes, freqs, audio_seq_len):
         | 
| 245 | 
            +
                    b, n, d = x.size(0), self.num_heads, self.head_dim
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                    q = self.norm_q(self.q(x)).view(b, -1, n, d)
         | 
| 248 | 
            +
                    k = self.norm_k(self.k(context)).view(b, -1, n, d)
         | 
| 249 | 
            +
                    v = self.v(context).view(b, -1, n, d)
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                    # --- NEW: derive sizes from shapes (SymInts), no int(tensor) casts ---
         | 
| 252 | 
            +
                    Lq = q.shape[1]                 # total video tokens per sample
         | 
| 253 | 
            +
                    # audio has 16 tokens per frame -> frames = audio_tokens // 16
         | 
| 254 | 
            +
                    frames = (context.shape[1] // 16)
         | 
| 255 | 
            +
                    hlen_wlen = Lq // frames        # tokens per frame = H*W
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                    # Now reshape using SymInt-derived sizes
         | 
| 258 | 
            +
                    q = q.reshape(-1, hlen_wlen, n, d)
         | 
| 259 | 
            +
                    k = k.reshape(-1, 16, n, d)
         | 
| 260 | 
            +
                    v = v.reshape(-1, 16, n, d)
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                    x = flash_attention(q, k, v, k_lens=None)
         | 
| 263 | 
            +
                    x = x.view(b, -1, n, d).flatten(2)
         | 
| 264 | 
            +
                    x = self.o(x)
         | 
| 265 | 
            +
                    return x
         | 
| 266 | 
            +
             | 
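A shape sketch of the per-frame grouping above, assuming one 16-token audio window per latent frame (which is what the hard-coded 16s in the reshapes encode):

B, F, H, W, n, d = 1, 6, 30, 52, 16, 128
q = torch.randn(B, F * H * W, n, d)       # video tokens
k = torch.randn(B, F * 16, n, d)          # audio tokens, 16 per frame
frames = k.shape[1] // 16
hlen_wlen = q.shape[1] // frames          # == H * W
q = q.reshape(-1, hlen_wlen, n, d)        # [B*F, H*W, n, d]
k = k.reshape(-1, 16, n, d)               # [B*F, 16,  n, d]
assert q.shape[0] == k.shape[0] == B * F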
| 267 | 
            +
                # def forward(self, x, context, context_lens, grid_sizes, freqs, audio_seq_len):
         | 
| 268 | 
            +
                #     r"""
         | 
| 269 | 
            +
                #     Args:
         | 
| 270 | 
            +
                #         x(Tensor): Shape [B, L1, C] - video tokens
         | 
| 271 | 
            +
                #         context(Tensor): Shape [B, L2, C] - audio tokens with shape [B, frames*16, 1536]
         | 
| 272 | 
            +
                #         context_lens(Tensor): Shape [B] - actually seq_lens from call (video sequence length)
         | 
| 273 | 
            +
                #         grid_sizes(Tensor): Shape [B, 3] - video grid dimensions (F, H, W)
         | 
| 274 | 
            +
                #         freqs(Tensor): RoPE frequencies
         | 
| 275 | 
            +
                #         audio_seq_len(Tensor): Actual audio sequence length (frames * 16)
         | 
| 276 | 
            +
                #     """
         | 
| 277 | 
            +
                #     b, n, d = x.size(0), self.num_heads, self.head_dim
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                #     q = self.norm_q(self.q(x)).view(b, -1, n, d)
         | 
| 280 | 
            +
                #     k = self.norm_k(self.k(context)).view(b, -1, n, d)
         | 
| 281 | 
            +
                #     v = self.v(context).view(b, -1, n, d)
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                #     # Handle video spatial structure
         | 
| 284 | 
            +
                #     hlen_wlen = int(grid_sizes[0][1] * grid_sizes[0][2])
         | 
| 285 | 
            +
                #     q = q.reshape(-1, hlen_wlen, n, d)
         | 
| 286 | 
            +
                    
         | 
| 287 | 
            +
                #     # Handle audio temporal structure (16 tokens per frame)
         | 
| 288 | 
            +
                #     k = k.reshape(-1, 16, n, d)
         | 
| 289 | 
            +
                #     v = v.reshape(-1, 16, n, d)
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                #     # Cross-attention
         | 
| 292 | 
            +
                #     x = flash_attention(q, k, v, k_lens=None)  # No masking for audio
         | 
| 293 | 
            +
                    
         | 
| 294 | 
            +
                #     x = x.view(b, -1, n, d).flatten(2)
         | 
| 295 | 
            +
                #     x = self.o(x)
         | 
| 296 | 
            +
                #     return x
         | 
| 297 | 
            +
             | 
| 298 | 
            +
             | 
| 299 | 
            +
            class AudioCrossAttentionWrapper(nn.Module):
         | 
| 300 | 
            +
                def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6,):
         | 
| 301 | 
            +
                    super().__init__()
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                    self.audio_cross_attn = WanT2VCrossAttentionGather(
         | 
| 304 | 
            +
                            kv_dim, dim, num_heads, (-1, -1), qk_norm, eps)
         | 
| 305 | 
            +
                    self.norm1_audio = WanLayerNorm(dim, eps,
         | 
| 306 | 
            +
                        elementwise_affine=True)
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                def forward(self, x, audio, seq_lens, grid_sizes, freqs, audio_seq_len):
         | 
| 309 | 
            +
                    x = x + self.audio_cross_attn(
         | 
| 310 | 
            +
                        self.norm1_audio(x), audio, seq_lens, grid_sizes, freqs, audio_seq_len)
         | 
| 311 | 
            +
                    return x
         | 
| 312 | 
            +
                    
         | 
| 313 | 
            +
             | 
| 314 | 
            +
            class WanI2VCrossAttention(WanSelfAttention):
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                def __init__(self,
         | 
| 317 | 
            +
                             dim,
         | 
| 318 | 
            +
                             num_heads,
         | 
| 319 | 
            +
                             window_size=(-1, -1),
         | 
| 320 | 
            +
                             qk_norm=True,
         | 
| 321 | 
            +
                             eps=1e-6):
         | 
| 322 | 
            +
                    super().__init__(dim, num_heads, window_size, qk_norm, eps)
         | 
| 323 | 
            +
             | 
| 324 | 
            +
                def forward(self, x, context, context_lens):
         | 
| 325 | 
            +
                    r"""
         | 
| 326 | 
            +
                    Args:
         | 
| 327 | 
            +
                        x(Tensor): Shape [B, L1, C]
         | 
| 328 | 
            +
                        context(Tensor): Shape [B, L2, C]
         | 
| 329 | 
            +
                        context_lens(Tensor): Shape [B]
         | 
| 330 | 
            +
                    """
         | 
| 331 | 
            +
                    b, n, d = x.size(0), self.num_heads, self.head_dim
         | 
| 332 | 
            +
             | 
| 333 | 
            +
                    # compute query, key, value
         | 
| 334 | 
            +
                    q = self.norm_q(self.q(x)).view(b, -1, n, d)
         | 
| 335 | 
            +
                    k = self.norm_k(self.k(context)).view(b, -1, n, d)
         | 
| 336 | 
            +
                    v = self.v(context).view(b, -1, n, d)
         | 
| 337 | 
            +
                    x = flash_attention(q, k, v, k_lens=context_lens)
         | 
| 338 | 
            +
                    
         | 
| 339 | 
            +
                    # output
         | 
| 340 | 
            +
                    x = x.flatten(2)
         | 
| 341 | 
            +
                    x = self.o(x)
         | 
| 342 | 
            +
                    return x
         | 
| 343 | 
            +
             | 
| 344 | 
            +
             | 
| 345 | 
            +
            WAN_CROSSATTENTION_CLASSES = {
         | 
| 346 | 
            +
                't2v_cross_attn': WanT2VCrossAttention,
         | 
| 347 | 
            +
                'i2v_cross_attn': WanI2VCrossAttention,
         | 
| 348 | 
            +
            }
         | 
| 349 | 
            +
             | 
| 350 | 
            +
            class WanAttentionBlock(nn.Module):
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                def __init__(self,
         | 
| 353 | 
            +
                             cross_attn_type,
         | 
| 354 | 
            +
                             dim,
         | 
| 355 | 
            +
                             ffn_dim,
         | 
| 356 | 
            +
                             num_heads,
         | 
| 357 | 
            +
                             window_size=(-1, -1),
         | 
| 358 | 
            +
                             qk_norm=True,
         | 
| 359 | 
            +
                             cross_attn_norm=False,
         | 
| 360 | 
            +
                             eps=1e-6,
         | 
| 361 | 
            +
                             use_audio=True):
         | 
| 362 | 
            +
                    super().__init__()
         | 
| 363 | 
            +
                    self.dim = dim
         | 
| 364 | 
            +
                    self.ffn_dim = ffn_dim
         | 
| 365 | 
            +
                    self.num_heads = num_heads
         | 
| 366 | 
            +
                    self.window_size = window_size
         | 
| 367 | 
            +
                    self.qk_norm = qk_norm
         | 
| 368 | 
            +
                    self.cross_attn_norm = cross_attn_norm
         | 
| 369 | 
            +
                    self.eps = eps
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                    # layers
         | 
| 372 | 
            +
                    self.norm1 = WanLayerNorm(dim, eps)
         | 
| 373 | 
            +
                    self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
         | 
| 374 | 
            +
                                                      eps)
         | 
| 375 | 
            +
                    self.norm3 = WanLayerNorm(
         | 
| 376 | 
            +
                        dim, eps,
         | 
| 377 | 
            +
                        elementwise_affine=True) if cross_attn_norm else nn.Identity()
         | 
| 378 | 
            +
                    self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
         | 
| 379 | 
            +
                                                                                  num_heads,
         | 
| 380 | 
            +
                                                                                  (-1, -1),
         | 
| 381 | 
            +
                                                                                  qk_norm,
         | 
| 382 | 
            +
                                                                                  eps)
         | 
| 383 | 
            +
                    self.norm2 = WanLayerNorm(dim, eps)
         | 
| 384 | 
            +
                    self.ffn = nn.Sequential(
         | 
| 385 | 
            +
                        nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
         | 
| 386 | 
            +
                        nn.Linear(ffn_dim, dim))
         | 
| 387 | 
            +
             | 
| 388 | 
            +
                    # modulation
         | 
| 389 | 
            +
                    self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                    self.use_audio = use_audio
         | 
| 392 | 
            +
                    if use_audio:
         | 
| 393 | 
            +
                        self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(dim, 1536, num_heads, qk_norm, eps)
         | 
| 394 | 
            +
             | 
| 395 | 
            +
                def forward(
         | 
| 396 | 
            +
                    self,
         | 
| 397 | 
            +
                    x, # torch.Size([1, 9360, 5120])
         | 
| 398 | 
            +
                    e, # torch.Size([1, 6, 5120])
         | 
| 399 | 
            +
                    seq_lens, # tensor([9360])
         | 
| 400 | 
            +
                    grid_sizes, # tensor([[ 6, 30, 52]])
         | 
| 401 | 
            +
                    freqs, # torch.Size([1024, 64])
         | 
| 402 | 
            +
                    context, # torch.Size([1, 512, 5120])
         | 
| 403 | 
            +
                    context_lens, # None
         | 
| 404 | 
            +
                    audio=None, # None
         | 
| 405 | 
            +
                    audio_seq_len=None,
         | 
| 406 | 
            +
                    ref_num_list=None,
         | 
| 407 | 
            +
                ):
         | 
| 408 | 
            +
                    r"""
         | 
| 409 | 
            +
                    Args:
         | 
| 410 | 
            +
                        x(Tensor): Shape [B, L, C]
         | 
| 411 | 
            +
                e(Tensor): Shape [B, 6, C], per-block modulation from time_projection
         | 
| 412 | 
            +
                audio(Tensor): Shape [B, F*16, 1536], 16 audio tokens per latent frame
         | 
| 413 | 
            +
                        seq_lens(Tensor): Shape [B], length of each sequence in batch
         | 
| 414 | 
            +
                        grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
         | 
| 415 | 
            +
                        freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         | 
| 416 | 
            +
                ref_num_list: used together with seq_lens to locate how far from the end of the sequence the reference image sits
         | 
| 417 | 
            +
                    """
         | 
| 418 | 
            +
                    assert e.dtype == torch.float32
         | 
| 419 | 
            +
                    with torch.amp.autocast('cuda', dtype=torch.float32):
         | 
| 420 | 
            +
                        e = (self.modulation + e).chunk(6, dim=1)
         | 
| 421 | 
            +
                    assert e[0].dtype == torch.float32
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                    # self-attention
         | 
| 424 | 
            +
                    y = self.self_attn(
         | 
| 425 | 
            +
                        self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
         | 
| 426 | 
            +
                        freqs)
         | 
| 427 | 
            +
                    with torch.amp.autocast('cuda', dtype=torch.float32):
         | 
| 428 | 
            +
                        x = x + y * e[2]
         | 
| 429 | 
            +
             | 
| 430 | 
            +
                    # cross-attention & ffn function
         | 
| 431 | 
            +
                    def cross_attn_ffn(x, context, context_lens, e):
         | 
| 432 | 
            +
                        x = x + self.cross_attn(self.norm3(x), context, context_lens)
         | 
| 433 | 
            +
                        
         | 
| 434 | 
            +
                        if self.use_audio:
         | 
| 435 | 
            +
                            x = self.audio_cross_attn_wrapper(x, audio, seq_lens, grid_sizes, freqs, audio_seq_len)
         | 
| 436 | 
            +
             | 
| 437 | 
            +
                        y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
         | 
| 438 | 
            +
                        with torch.amp.autocast('cuda', dtype=torch.float32):
         | 
| 439 | 
            +
                            x = x + y * e[5]
         | 
| 440 | 
            +
                        return x
         | 
| 441 | 
            +
             | 
| 442 | 
            +
                    x = cross_attn_ffn(x, context, context_lens, e)
         | 
| 443 | 
            +
             | 
| 444 | 
            +
                    return x
         | 
| 445 | 
            +
             | 
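As a reading aid for the six chunks consumed above (a labeling sketch, assuming dim=2048): e[0]/e[1] shift and scale the self-attention input, e[2] gates its output; e[3]/e[4] shift and scale the FFN input, e[5] gates its output.

dim = 2048
e0 = torch.randn(1, 6, dim)               # output of time_projection
modulation = torch.randn(1, 6, dim) / dim ** 0.5
(shift_sa, scale_sa, gate_sa,
 shift_ffn, scale_ffn, gate_ffn) = (modulation + e0).chunk(6, dim=1)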
| 446 | 
            +
             | 
class Head(nn.Module):

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # layers
        out_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, out_dim)

        # modulation
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, e):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, C]
        """
        assert e.dtype == torch.float32
        with torch.amp.autocast('cuda', dtype=torch.float32):
            e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
            x = self.head(self.norm(x) * (1 + e[1]) + e[0])
        return x


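# Editor's note (illustrative arithmetic): Head above projects each token from
# `dim` to math.prod(patch_size) * out_dim values; with the defaults
# patch_size=(1, 2, 2) and out_dim=16 that is 1 * 2 * 2 * 16 = 64 values per
# token, which unpatchify later folds back into the latent grid. Unlike the
# blocks' six-way modulation, the head uses only a shift/scale pair (chunk(2)).
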
class MLPProj(torch.nn.Module):

    def __init__(self, in_dim, out_dim):
        super().__init__()

        self.proj = torch.nn.Sequential(
            torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
            torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
            torch.nn.LayerNorm(out_dim))

    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens


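# Editor's note (assumption, inferred from the variable names): MLPProj maps
# CLIP image embeddings to extra context tokens for image-conditioned
# cross-attention, via LayerNorm -> Linear -> GELU -> Linear -> LayerNorm.
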
class WanModel(nn.Module):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    """

    ignore_for_config = [
        'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
    ]
    _no_split_modules = ['WanAttentionBlock']

    gradient_checkpointing = False

    def __init__(self,
                 model_type='t2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=13824,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=40,
                 num_layers=40,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 audio_token_num=16,
                 insert_audio=True):
        r"""
        Initialize the diffusion model backbone.

        Args:
            model_type (`str`, *optional*, defaults to 't2v'):
                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
            text_len (`int`, *optional*, defaults to 512):
                Fixed length for text embeddings
            in_dim (`int`, *optional*, defaults to 16):
                Input video channels (C_in)
            dim (`int`, *optional*, defaults to 2048):
                Hidden dimension of the transformer
            ffn_dim (`int`, *optional*, defaults to 13824):
                Intermediate dimension in feed-forward network
            freq_dim (`int`, *optional*, defaults to 256):
                Dimension for sinusoidal time embeddings
            text_dim (`int`, *optional*, defaults to 4096):
                Input dimension for text embeddings
            out_dim (`int`, *optional*, defaults to 16):
                Output video channels (C_out)
            num_heads (`int`, *optional*, defaults to 40):
                Number of attention heads
            num_layers (`int`, *optional*, defaults to 40):
                Number of transformer blocks
            window_size (`tuple`, *optional*, defaults to (-1, -1)):
                Window size for local attention (-1 indicates global attention)
            qk_norm (`bool`, *optional*, defaults to True):
                Enable query/key normalization
            cross_attn_norm (`bool`, *optional*, defaults to True):
                Enable cross-attention normalization
            eps (`float`, *optional*, defaults to 1e-6):
                Epsilon value for normalization layers
            audio_token_num (`int`, *optional*, defaults to 16):
                Number of audio context tokens produced per audio frame
            insert_audio (`bool`, *optional*, defaults to True):
                Whether to insert audio cross-attention into each block
        """

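        # Editor's note (hypothetical example, values not from this file):
        # a smaller variant might be constructed as
        #   model = WanModel(model_type='t2v', dim=1536, ffn_dim=8960,
        #                    num_heads=12, num_layers=30)
        # Any configuration must satisfy the RoPE assert below:
        # dim % num_heads == 0 and (dim // num_heads) % 2 == 0.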
        super().__init__()

        assert model_type in ['t2v', 'i2v']
        self.model_type = model_type

        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim))

        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
        self.insert_audio = insert_audio
        self.blocks = nn.ModuleList([
            WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
                        window_size, qk_norm, cross_attn_norm,
                        eps, use_audio=self.insert_audio)
            for _ in range(num_layers)
        ])

        # head
        self.head = Head(dim, out_dim, patch_size, eps)

        if self.insert_audio:
            self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280,
                intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num)

        # RoPE freqs: register as a buffer so it moves with .to() / DDP and is tracked by compile
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads

        _freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ], dim=1)
        self.register_buffer("freqs", _freqs, persistent=False)

        # initialize weights
        self.init_weights()

        # initialize unified parallel
        if is_unified_parallel_initialized():
            print("Initializing WanModel with unified parallel enabled")
            from humo.models.distributed.dit_ulysses_sequence_parallel import ulysses_attn_forward, ulysses_dit_forward, ulysses_audio_cross_attn_forward
            for block in self.blocks:
                block.self_attn.forward = types.MethodType(ulysses_attn_forward, block.self_attn)
                if block.use_audio:
                    block.audio_cross_attn_wrapper.audio_cross_attn.forward = types.MethodType(ulysses_audio_cross_attn_forward, block.audio_cross_attn_wrapper.audio_cross_attn)
            self.forward = types.MethodType(ulysses_dit_forward, self)

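    # Editor's note (illustrative, not in the original source): the RoPE table
    # built in __init__ splits each head's channel pairs across the three video
    # axes: d - 4*(d//6) channels rotate with the frame index and 2*(d//6) each
    # with height and width. E.g. for head dim d = 128: d//6 = 21, giving
    # 44 (time) + 42 (height) + 42 (width) = 128.
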
    def forward(
        self,
        x,
        t,
        context,
        seq_len,
        audio=None,
        y=None,
        tea_cache=None,
    ):
        r"""
        Forward pass through the diffusion model

        Args:
            x (List[Tensor]):
                List of input video tensors, each with shape [C_in, F, H, W]
            t (Tensor):
                Diffusion timesteps tensor of shape [B]
            context (List[Tensor]):
                List of text embeddings each with shape [L, C]
            seq_len (`int`):
                Maximum sequence length for positional encoding
            audio (List[Tensor], *optional*):
                Audio feature sequences, projected by `audio_proj` into
                per-frame context tokens for audio cross-attention
            y (List[Tensor], *optional*):
                Conditional video inputs for image-to-video mode, same shape as x
            tea_cache (*optional*):
                TeaCache helper used to skip redundant block computation
                across denoising steps

        Returns:
            List[Tensor]:
                List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
        """
        if self.model_type == 'i2v':
            assert y is not None

        # params
        freqs = self.freqs

        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])

        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        assert seq_lens.max() <= seq_len

        # pad to uniform length and batch
        x = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
            for u in x
        ])  # shape: [B, seq_len, C]

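        # Editor's note (illustrative, assumed latent size): patch_embedding is
        # a Conv3d with kernel_size == stride == patch_size, so a latent of
        # shape [C_in, F, H, W] becomes (F/1) * (H/2) * (W/2) tokens under the
        # default (1, 2, 2) patching; e.g. a hypothetical [16, 21, 60, 104]
        # latent yields 21 * 30 * 52 = 32760 tokens of width `dim`.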
| 685 | 
            +
                    with torch.amp.autocast('cuda', dtype=torch.float32):
         | 
| 686 | 
            +
                        e = self.time_embedding(
         | 
| 687 | 
            +
                            sinusoidal_embedding_1d(self.freq_dim, t).float()
         | 
| 688 | 
            +
                        ).float()
         | 
| 689 | 
            +
                        e0 = self.time_projection(e).unflatten(1, (6, self.dim)).float()
         | 
| 690 | 
            +
                        assert e.dtype == torch.float32 and e0.dtype == torch.float32
         | 
| 691 | 
            +
             | 
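        # Editor's note (illustrative): `e` ([B, dim]) conditions the output
        # Head, while `e0` ([B, 6, dim]) carries the per-sub-layer
        # shift/scale/gate parameters consumed by each WanAttentionBlock.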
        # context
        context_lens = None
        context = self.text_embedding(
            torch.stack([
                torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ])
        )

        # audio (unchanged; not cached)
        if self.insert_audio:
            audio = [self.audio_proj(au.unsqueeze(0)).permute(0, 3, 1, 2) for au in audio]
            audio_seq_len = max(au.shape[2] for au in audio) * audio[0].shape[3]

            audio = [au.flatten(2).transpose(1, 2) for au in audio]  # [1, t * audio_token_num, 1536]
            audio = torch.cat([
                torch.cat([au, au.new_zeros(1, int(audio_seq_len) - au.size(1), au.size(2))], dim=1)
                for au in audio
            ])
        else:
            audio = None
            audio_seq_len = None

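        # Editor's note (illustrative): each audio frame is projected to
        # `audio_token_num` context tokens of width 1536; frames are flattened
        # into one sequence and zero-padded to audio_seq_len so the batch can
        # be concatenated, mirroring the video-token padding above.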
        # ---- tea_cache integration ----
        if tea_cache is not None:
            # Use the pre-block tokens 'x' and time-mod 'e0' to decide whether to reuse cache
            tea_cache_update = tea_cache.check(self, x, e0)
        else:
            tea_cache_update = False

        ori_x_len = x.shape[1]  # remember original token length before potential cache extension

        if tea_cache_update:
            # Let the cache inject/append any needed past states/tokens for reuse
            x = tea_cache.update(x)
        else:
            # arguments for blocks
            kwargs = dict(
                e=e0,
                seq_lens=seq_lens,
                grid_sizes=grid_sizes,
                freqs=freqs,
                context=context,
                context_lens=context_lens,
                audio=audio,
                audio_seq_len=audio_seq_len
            )

            # transformer blocks
            for block in self.blocks:
                x = block(x, **kwargs)

            if tea_cache is not None:
                x_cache = x[:, :ori_x_len]
                tea_cache.store(x_cache)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return [u.float() for u in x]


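    # Editor's note (illustrative): when TeaCache judges the current timestep
    # close enough to a cached one, the block stack is skipped entirely and
    # cached activations are reused via tea_cache.update(); otherwise all
    # blocks run and the fresh activations, trimmed back to ori_x_len tokens,
    # are stored for later denoising steps.
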
    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                    shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """

        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

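    # Editor's note (illustrative shape walk-through): for a grid (F', H', W')
    # and patch_size (1, 2, 2), each token's 2*2*c values are viewed as
    # [F', H', W', 1, 2, 2, c]; the einsum 'fhwpqrc->cfphqwr' interleaves patch
    # offsets with grid positions, and the reshape yields [c, F'*1, H'*2, W'*2].
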
    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.
        """

        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)
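
# Editor's note (illustrative): zero-initializing the final head projection
# makes the freshly initialized network output an all-zero prediction, the
# adaLN-Zero-style trick used by DiT to stabilize early diffusion training.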
    	
humo/models/wan_modules/t5.py
ADDED
@@ -0,0 +1,525 @@
# Modified from transformers.models.t5.modeling_t5
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .tokenizers import HuggingfaceTokenizer

__all__ = [
    'T5Model',
    'T5Encoder',
    'T5Decoder',
    'T5EncoderModel',
]


def fp16_clamp(x):
    if x.dtype == torch.float16 and torch.isinf(x).any():
        clamp = torch.finfo(x.dtype).max - 1000
        x = torch.clamp(x, min=-clamp, max=clamp)
    return x


def init_weights(m):
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
    elif isinstance(m, T5Model):
        nn.init.normal_(m.token_embedding.weight, std=1.0)
    elif isinstance(m, T5FeedForward):
        nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
    elif isinstance(m, T5Attention):
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
        nn.init.normal_(m.k.weight, std=m.dim**-0.5)
        nn.init.normal_(m.v.weight, std=m.dim**-0.5)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
    elif isinstance(m, T5RelativeEmbedding):
        nn.init.normal_(
            m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)


class GELU(nn.Module):

    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(
            math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


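# Editor's note (illustrative): GELU above is the tanh approximation,
# 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))), numerically
# close to torch.nn.GELU(approximate='tanh').
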
class T5LayerNorm(nn.Module):

    def __init__(self, dim, eps=1e-6):
        super(T5LayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
                            self.eps)
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            x = x.type_as(self.weight)
        return self.weight * x


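# Editor's note (illustrative): T5LayerNorm above is an RMSNorm: it rescales
# by the root mean square of the activations, x * rsqrt(mean(x^2) + eps),
# with a learned weight but no mean subtraction and no bias.
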
class T5Attention(nn.Module):

    def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
        assert dim_attn % num_heads == 0
        super(T5Attention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        # layers
        self.q = nn.Linear(dim, dim_attn, bias=False)
        self.k = nn.Linear(dim, dim_attn, bias=False)
        self.v = nn.Linear(dim, dim_attn, bias=False)
        self.o = nn.Linear(dim_attn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, pos_bias=None):
        """
        x:          [B, L1, C].
        context:    [B, L2, C] or None.
        mask:       [B, L2] or [B, L1, L2] or None.
        """
        # check inputs
        context = x if context is None else context
        b, n, c = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).view(b, -1, n, c)
        k = self.k(context).view(b, -1, n, c)
        v = self.v(context).view(b, -1, n, c)

        # attention bias
        attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
        if pos_bias is not None:
            attn_bias += pos_bias
        if mask is not None:
            assert mask.ndim in [2, 3]
            mask = mask.view(b, 1, 1,
                             -1) if mask.ndim == 2 else mask.unsqueeze(1)
            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)

        # compute attention (T5 does not use scaling)
        attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
        x = torch.einsum('bnij,bjnc->binc', attn, v)

        # output
        x = x.reshape(b, -1, n * c)
        x = self.o(x)
        x = self.dropout(x)
        return x


class T5FeedForward(nn.Module):

    def __init__(self, dim, dim_ffn, dropout=0.1):
        super(T5FeedForward, self).__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        # layers
        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x) * self.gate(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


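# Editor's note (illustrative): T5FeedForward above is the gated-GELU
# ("GEGLU") variant used by T5 v1.1: two parallel up-projections combined as
# fc1(x) * GELU(W_g x), followed by a single down-projection fc2.
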
class T5SelfAttention(nn.Module):

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5SelfAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True)

    def forward(self, x, mask=None, pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(
            x.size(1), x.size(1))
        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.ffn(self.norm2(x)))
        return x


class T5CrossAttention(nn.Module):

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5CrossAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm3 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=False)

    def forward(self,
                x,
                mask=None,
                encoder_states=None,
                encoder_mask=None,
                pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(
            x.size(1), x.size(1))
        x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.cross_attn(
            self.norm2(x), context=encoder_states, mask=encoder_mask))
        x = fp16_clamp(x + self.ffn(self.norm3(x)))
        return x


class T5RelativeEmbedding(nn.Module):

    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super(T5RelativeEmbedding, self).__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # layers
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, lq, lk):
        device = self.embedding.weight.device
        rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
            torch.arange(lq, device=device).unsqueeze(1)
        rel_pos = self._relative_position_bucket(rel_pos)
        rel_pos_embeds = self.embedding(rel_pos)
        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
            0)  # [1, N, Lq, Lk]
        return rel_pos_embeds.contiguous()

    def _relative_position_bucket(self, rel_pos):
        # preprocess
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).long() * num_buckets
            rel_pos = torch.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            rel_buckets = 0
            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))

        # embeddings for small and large positions
        max_exact = num_buckets // 2
        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
                                     math.log(self.max_dist / max_exact) *
                                     (num_buckets - max_exact)).long()
        rel_pos_large = torch.min(
            rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
        return rel_buckets


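# Editor's note (illustrative bucket arithmetic, assuming the common T5
# setting num_buckets=32): bidirectional attention gets 16 buckets per
# direction; max_exact = 8, so distances 0..7 map to exact buckets and
# distances 8..max_dist (128) are binned logarithmically into the remaining
# 8 buckets, with anything farther clamped to the last bucket.
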
class T5Encoder(nn.Module):

    def __init__(self,
                 vocab,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5Encoder, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
            else nn.Embedding(vocab, dim)
        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                            shared_pos, dropout) for _ in range(num_layers)
        ])
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None):
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1),
                               x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


| 315 | 
            +
            class T5Decoder(nn.Module):
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                def __init__(self,
         | 
| 318 | 
            +
                             vocab,
         | 
| 319 | 
            +
                             dim,
         | 
| 320 | 
            +
                             dim_attn,
         | 
| 321 | 
            +
                             dim_ffn,
         | 
| 322 | 
            +
                             num_heads,
         | 
| 323 | 
            +
                             num_layers,
         | 
| 324 | 
            +
                             num_buckets,
         | 
| 325 | 
            +
                             shared_pos=True,
         | 
| 326 | 
            +
                             dropout=0.1):
         | 
| 327 | 
            +
                    super(T5Decoder, self).__init__()
         | 
| 328 | 
            +
                    self.dim = dim
         | 
| 329 | 
            +
                    self.dim_attn = dim_attn
         | 
| 330 | 
            +
                    self.dim_ffn = dim_ffn
         | 
| 331 | 
            +
                    self.num_heads = num_heads
         | 
| 332 | 
            +
                    self.num_layers = num_layers
         | 
| 333 | 
            +
                    self.num_buckets = num_buckets
         | 
| 334 | 
            +
                    self.shared_pos = shared_pos
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                    # layers
         | 
| 337 | 
            +
                    self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
         | 
| 338 | 
            +
                        else nn.Embedding(vocab, dim)
         | 
| 339 | 
            +
                    self.pos_embedding = T5RelativeEmbedding(
         | 
| 340 | 
            +
                        num_buckets, num_heads, bidirectional=False) if shared_pos else None
         | 
| 341 | 
            +
                    self.dropout = nn.Dropout(dropout)
         | 
| 342 | 
            +
                    self.blocks = nn.ModuleList([
         | 
| 343 | 
            +
                        T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
         | 
| 344 | 
            +
                                         shared_pos, dropout) for _ in range(num_layers)
         | 
| 345 | 
            +
                    ])
         | 
| 346 | 
            +
                    self.norm = T5LayerNorm(dim)
         | 
| 347 | 
            +
             | 
| 348 | 
            +
                    # initialize weights
         | 
| 349 | 
            +
                    self.apply(init_weights)
         | 
| 350 | 
            +
             | 
| 351 | 
            +
                def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
         | 
| 352 | 
            +
                    b, s = ids.size()
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                    # causal mask
         | 
| 355 | 
            +
                    if mask is None:
         | 
| 356 | 
            +
                        mask = torch.tril(torch.ones(1, s, s).to(ids.device))
         | 
| 357 | 
            +
                    elif mask.ndim == 2:
         | 
| 358 | 
            +
                        mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
         | 
| 359 | 
            +
             | 
| 360 | 
            +
                    # layers
         | 
| 361 | 
            +
                    x = self.token_embedding(ids)
         | 
| 362 | 
            +
                    x = self.dropout(x)
         | 
| 363 | 
            +
                    e = self.pos_embedding(x.size(1),
         | 
| 364 | 
            +
                                           x.size(1)) if self.shared_pos else None
         | 
| 365 | 
            +
                    for block in self.blocks:
         | 
| 366 | 
            +
                        x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
         | 
| 367 | 
            +
                    x = self.norm(x)
         | 
| 368 | 
            +
                    x = self.dropout(x)
         | 
| 369 | 
            +
                    return x
         | 
| 370 | 
            +
             | 
| 371 | 
            +
             | 
| 372 | 
            +
            class T5Model(nn.Module):
         | 
| 373 | 
            +
             | 
| 374 | 
            +
                def __init__(self,
         | 
| 375 | 
            +
                             vocab_size,
         | 
| 376 | 
            +
                             dim,
         | 
| 377 | 
            +
                             dim_attn,
         | 
| 378 | 
            +
                             dim_ffn,
         | 
| 379 | 
            +
                             num_heads,
         | 
| 380 | 
            +
                             encoder_layers,
         | 
| 381 | 
            +
                             decoder_layers,
         | 
| 382 | 
            +
                             num_buckets,
         | 
| 383 | 
            +
                             shared_pos=True,
         | 
| 384 | 
            +
                             dropout=0.1):
         | 
| 385 | 
            +
                    super(T5Model, self).__init__()
         | 
| 386 | 
            +
                    self.vocab_size = vocab_size
         | 
| 387 | 
            +
                    self.dim = dim
         | 
| 388 | 
            +
                    self.dim_attn = dim_attn
         | 
| 389 | 
            +
                    self.dim_ffn = dim_ffn
         | 
| 390 | 
            +
                    self.num_heads = num_heads
         | 
| 391 | 
            +
                    self.encoder_layers = encoder_layers
         | 
| 392 | 
            +
                    self.decoder_layers = decoder_layers
         | 
| 393 | 
            +
                    self.num_buckets = num_buckets
         | 
| 394 | 
            +
             | 
| 395 | 
            +
                    # layers
         | 
| 396 | 
            +
                    self.token_embedding = nn.Embedding(vocab_size, dim)
         | 
| 397 | 
            +
                    self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
         | 
| 398 | 
            +
                                             num_heads, encoder_layers, num_buckets,
         | 
| 399 | 
            +
                                             shared_pos, dropout)
         | 
| 400 | 
            +
                    self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
         | 
| 401 | 
            +
                                             num_heads, decoder_layers, num_buckets,
         | 
| 402 | 
            +
                                             shared_pos, dropout)
         | 
| 403 | 
            +
                    self.head = nn.Linear(dim, vocab_size, bias=False)
         | 
| 404 | 
            +
             | 
| 405 | 
            +
                    # initialize weights
         | 
| 406 | 
            +
                    self.apply(init_weights)
         | 
| 407 | 
            +
             | 
| 408 | 
            +
                def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
         | 
| 409 | 
            +
                    x = self.encoder(encoder_ids, encoder_mask)
         | 
| 410 | 
            +
                    x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
         | 
| 411 | 
            +
                    x = self.head(x)
         | 
| 412 | 
            +
                    return x
         | 
| 413 | 
            +
             | 
| 414 | 
            +
             | 
| 415 | 
            +
            def _t5(name,
         | 
| 416 | 
            +
                    encoder_only=False,
         | 
| 417 | 
            +
                    decoder_only=False,
         | 
| 418 | 
            +
                    return_tokenizer=False,
         | 
| 419 | 
            +
                    tokenizer_kwargs={},
         | 
| 420 | 
            +
                    dtype=torch.float32,
         | 
| 421 | 
            +
                    device='cpu',
         | 
| 422 | 
            +
                    **kwargs):
         | 
| 423 | 
            +
                # sanity check
         | 
| 424 | 
            +
                assert not (encoder_only and decoder_only)
         | 
| 425 | 
            +
             | 
| 426 | 
            +
                # params
         | 
| 427 | 
            +
                if encoder_only:
         | 
| 428 | 
            +
                    model_cls = T5Encoder
         | 
| 429 | 
            +
                    kwargs['vocab'] = kwargs.pop('vocab_size')
         | 
| 430 | 
            +
                    kwargs['num_layers'] = kwargs.pop('encoder_layers')
         | 
| 431 | 
            +
                    _ = kwargs.pop('decoder_layers')
         | 
| 432 | 
            +
                elif decoder_only:
         | 
| 433 | 
            +
                    model_cls = T5Decoder
         | 
| 434 | 
            +
                    kwargs['vocab'] = kwargs.pop('vocab_size')
         | 
| 435 | 
            +
                    kwargs['num_layers'] = kwargs.pop('decoder_layers')
         | 
| 436 | 
            +
                    _ = kwargs.pop('encoder_layers')
         | 
| 437 | 
            +
                else:
         | 
| 438 | 
            +
                    model_cls = T5Model
         | 
| 439 | 
            +
             | 
| 440 | 
            +
                # init model
         | 
| 441 | 
            +
                with torch.device(device):
         | 
| 442 | 
            +
                    model = model_cls(**kwargs)
         | 
| 443 | 
            +
             | 
| 444 | 
            +
                # set device
         | 
| 445 | 
            +
                model = model.to(dtype=dtype, device=device)
         | 
| 446 | 
            +
             | 
| 447 | 
            +
                # init tokenizer
         | 
| 448 | 
            +
                if return_tokenizer:
         | 
| 449 | 
            +
                    from .tokenizers import HuggingfaceTokenizer
         | 
| 450 | 
            +
                    tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
         | 
| 451 | 
            +
                    return model, tokenizer
         | 
| 452 | 
            +
                else:
         | 
| 453 | 
            +
                    return model
         | 
| 454 | 
            +
             | 
| 455 | 
            +
             | 
| 456 | 
            +
            def umt5_xxl(**kwargs):
         | 
| 457 | 
            +
                cfg = dict(
         | 
| 458 | 
            +
                    vocab_size=256384,
         | 
| 459 | 
            +
                    dim=4096,
         | 
| 460 | 
            +
                    dim_attn=4096,
         | 
| 461 | 
            +
                    dim_ffn=10240,
         | 
| 462 | 
            +
                    num_heads=64,
         | 
| 463 | 
            +
                    encoder_layers=24,
         | 
| 464 | 
            +
                    decoder_layers=24,
         | 
| 465 | 
            +
                    num_buckets=32,
         | 
| 466 | 
            +
                    shared_pos=False,
         | 
| 467 | 
            +
                    dropout=0.1)
         | 
| 468 | 
            +
                cfg.update(**kwargs)
         | 
| 469 | 
            +
                return _t5('umt5-xxl', **cfg)
         | 
| 470 | 
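# Usage note (illustrative, not part of the pipeline): `umt5_xxl` only builds
# the architecture. For example, `umt5_xxl(encoder_only=True)` returns a bare,
# randomly initialized T5Encoder with the decoder-side config keys stripped by
# `_t5`; pretrained weights are loaded separately, as T5EncoderModel does below.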
            +
             | 
| 471 | 
            +
             | 
| 472 | 
            +
            class T5EncoderModel(nn.Module):
         | 
| 473 | 
            +
             | 
| 474 | 
            +
                def __init__(
         | 
| 475 | 
            +
                    self,
         | 
| 476 | 
            +
                    text_len,
         | 
| 477 | 
            +
                    dtype=torch.bfloat16,
         | 
| 478 | 
            +
                    device=torch.cuda.current_device(),
         | 
| 479 | 
            +
                    checkpoint_path=None,
         | 
| 480 | 
            +
                    tokenizer_path=None,
         | 
| 481 | 
            +
                    shard_fn=None,
         | 
| 482 | 
            +
                ):
         | 
| 483 | 
            +
                    super(T5EncoderModel, self).__init__()
         | 
| 484 | 
            +
                    self.text_len = text_len
         | 
| 485 | 
            +
                    self.dtype = dtype
         | 
| 486 | 
            +
                    self.device = device
         | 
| 487 | 
            +
                    self.checkpoint_path = checkpoint_path
         | 
| 488 | 
            +
                    self.tokenizer_path = tokenizer_path
         | 
| 489 | 
            +
             | 
| 490 | 
            +
                    with torch.device(device):
         | 
| 491 | 
            +
                        self.model = T5Encoder(
         | 
| 492 | 
            +
                            vocab=256384,
         | 
| 493 | 
            +
                            dim=4096,
         | 
| 494 | 
            +
                            dim_attn=4096,
         | 
| 495 | 
            +
                            dim_ffn=10240,
         | 
| 496 | 
            +
                            num_heads=64,
         | 
| 497 | 
            +
                            num_layers=24,
         | 
| 498 | 
            +
                            num_buckets=32,
         | 
| 499 | 
            +
                            shared_pos=False,
         | 
| 500 | 
            +
                            dropout=0.1
         | 
| 501 | 
            +
                        )
         | 
| 502 | 
            +
                    # set device
         | 
| 503 | 
            +
                    self.model = self.model.to(dtype=dtype, device=device).eval().requires_grad_(False)
         | 
| 504 | 
            +
             | 
| 505 | 
            +
                    logging.info(f'loading {checkpoint_path}')
         | 
| 506 | 
            +
                    if checkpoint_path is not None:
         | 
| 507 | 
            +
                        self.model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
         | 
| 508 | 
            +
                    
         | 
| 509 | 
            +
                    if shard_fn is not None:
         | 
| 510 | 
            +
                        self.model = shard_fn(self.model, sync_module_states=False)
         | 
| 511 | 
            +
                    else:
         | 
| 512 | 
            +
                        self.model.to(self.device)
         | 
| 513 | 
            +
                    # init tokenizer
         | 
| 514 | 
            +
                    self.tokenizer = HuggingfaceTokenizer(
         | 
| 515 | 
            +
                        name=tokenizer_path, seq_len=text_len, clean='whitespace')
         | 
| 516 | 
            +
             | 
| 517 | 
            +
                @torch.no_grad()
         | 
| 518 | 
            +
                def __call__(self, texts, device):
         | 
| 519 | 
            +
                    ids, mask = self.tokenizer(
         | 
| 520 | 
            +
                        texts, return_mask=True, add_special_tokens=True)
         | 
| 521 | 
            +
                    ids = ids.to(device)
         | 
| 522 | 
            +
                    mask = mask.to(device)
         | 
| 523 | 
            +
                    seq_lens = mask.gt(0).sum(dim=1).long()
         | 
| 524 | 
            +
                    context = self.model(ids, mask)
         | 
| 525 | 
            +
                    return [u[:v] for u, v in zip(context, seq_lens)]
         | 
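For orientation, a minimal sketch of driving this encoder end to end; the checkpoint and tokenizer paths below are placeholders, not files shipped in this diff:

    import torch
    from humo.models.wan_modules.t5 import T5EncoderModel

    # Hypothetical paths; substitute the UMT5-XXL weights you actually downloaded.
    text_encoder = T5EncoderModel(
        text_len=512,
        dtype=torch.bfloat16,
        device='cuda',
        checkpoint_path='models_t5_umt5-xxl-enc-bf16.pth',
        tokenizer_path='google/umt5-xxl',
    )
    # Returns one (seq_len_i, 4096) tensor per prompt, trimmed to its true length.
    context = text_encoder(['a corgi surfing a wave at sunset'], device='cuda')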
    	
humo/models/wan_modules/tokenizers.py
ADDED
@@ -0,0 +1,82 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import html
import string

import ftfy
import regex as re
from transformers import AutoTokenizer

__all__ = ['HuggingfaceTokenizer']


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


def canonicalize(text, keep_punctuation_exact_string=None):
    text = text.replace('_', ' ')
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(
            part.translate(str.maketrans('', '', string.punctuation))
            for part in text.split(keep_punctuation_exact_string))
    else:
        text = text.translate(str.maketrans('', '', string.punctuation))
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


class HuggingfaceTokenizer:

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, 'whitespace', 'lower', 'canonicalize')
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        return_mask = kwargs.pop('return_mask', False)

        # arguments
        _kwargs = {'return_tensors': 'pt'}
        if self.seq_len is not None:
            _kwargs.update({
                'padding': 'max_length',
                'truncation': True,
                'max_length': self.seq_len
            })
        _kwargs.update(**kwargs)

        # tokenization
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(u) for u in sequence]
        ids = self.tokenizer(sequence, **_kwargs)

        # output
        if return_mask:
            return ids.input_ids, ids.attention_mask
        else:
            return ids.input_ids

    def _clean(self, text):
        if self.clean == 'whitespace':
            text = whitespace_clean(basic_clean(text))
        elif self.clean == 'lower':
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == 'canonicalize':
            text = canonicalize(basic_clean(text))
        return text
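A quick illustration of the wrapper above; the tokenizer name is a placeholder, and any Hugging Face tokenizer id would work the same way:

    from humo.models.wan_modules.tokenizers import HuggingfaceTokenizer

    tok = HuggingfaceTokenizer('google/umt5-xxl', seq_len=512, clean='whitespace')
    ids, mask = tok('A  corgi\tsurfing', return_mask=True)
    # ids:  (1, 512) input_ids, right-padded to seq_len
    # mask: (1, 512) attention_mask, 1 for real tokens, 0 for padding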
    	
humo/models/wan_modules/vae.py
ADDED
@@ -0,0 +1,666 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging

import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

__all__ = [
    'WanVAE',
]

CACHE_T = 2


class CausalConv3d(nn.Conv3d):
    """
    Causal 3d convolution.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._padding = (self.padding[2], self.padding[2], self.padding[1],
                         self.padding[1], 2 * self.padding[0], 0)
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)

        return super().forward(x)
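# Note: `_padding` pads the temporal axis on the left only (2 * pad frames),
# so each output frame depends solely on current and past frames; `cache_x`
# lets chunked inference substitute real trailing frames from the previous
# chunk for that left padding.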
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            class RMS_norm(nn.Module):
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                def __init__(self, dim, channel_first=True, images=True, bias=False):
         | 
| 42 | 
            +
                    super().__init__()
         | 
| 43 | 
            +
                    broadcastable_dims = (1, 1, 1) if not images else (1, 1)
         | 
| 44 | 
            +
                    shape = (dim, *broadcastable_dims) if channel_first else (dim,)
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                    self.channel_first = channel_first
         | 
| 47 | 
            +
                    self.scale = dim**0.5
         | 
| 48 | 
            +
                    self.gamma = nn.Parameter(torch.ones(shape))
         | 
| 49 | 
            +
                    self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                def forward(self, x):
         | 
| 52 | 
            +
                    return F.normalize(
         | 
| 53 | 
            +
                        x, dim=(1 if self.channel_first else
         | 
| 54 | 
            +
                                -1)) * self.scale * self.gamma + self.bias
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            class Upsample(nn.Upsample):
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                def forward(self, x):
         | 
| 60 | 
            +
                    """
         | 
| 61 | 
            +
                    Fix bfloat16 support for nearest neighbor interpolation.
         | 
| 62 | 
            +
                    """
         | 
| 63 | 
            +
                    return super().forward(x.float()).type_as(x)
         | 
| 64 | 
            +
             | 
| 65 | 
            +
             | 
| 66 | 
            +
            class Resample(nn.Module):
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                def __init__(self, dim, mode):
         | 
| 69 | 
            +
                    assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
         | 
| 70 | 
            +
                                    'downsample3d')
         | 
| 71 | 
            +
                    super().__init__()
         | 
| 72 | 
            +
                    self.dim = dim
         | 
| 73 | 
            +
                    self.mode = mode
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                    # layers
         | 
| 76 | 
            +
                    if mode == 'upsample2d':
         | 
| 77 | 
            +
                        self.resample = nn.Sequential(
         | 
| 78 | 
            +
                            Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
         | 
| 79 | 
            +
                            nn.Conv2d(dim, dim // 2, 3, padding=1))
         | 
| 80 | 
            +
                    elif mode == 'upsample3d':
         | 
| 81 | 
            +
                        self.resample = nn.Sequential(
         | 
| 82 | 
            +
                            Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
         | 
| 83 | 
            +
                            nn.Conv2d(dim, dim // 2, 3, padding=1))
         | 
| 84 | 
            +
                        self.time_conv = CausalConv3d(
         | 
| 85 | 
            +
                            dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    elif mode == 'downsample2d':
         | 
| 88 | 
            +
                        self.resample = nn.Sequential(
         | 
| 89 | 
            +
                            nn.ZeroPad2d((0, 1, 0, 1)),
         | 
| 90 | 
            +
                            nn.Conv2d(dim, dim, 3, stride=(2, 2)))
         | 
| 91 | 
            +
                    elif mode == 'downsample3d':
         | 
| 92 | 
            +
                        self.resample = nn.Sequential(
         | 
| 93 | 
            +
                            nn.ZeroPad2d((0, 1, 0, 1)),
         | 
| 94 | 
            +
                            nn.Conv2d(dim, dim, 3, stride=(2, 2)))
         | 
| 95 | 
            +
                        self.time_conv = CausalConv3d(
         | 
| 96 | 
            +
                            dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    else:
         | 
| 99 | 
            +
                        self.resample = nn.Identity()
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                def forward(self, x, feat_cache=None, feat_idx=[0]):
         | 
| 102 | 
            +
                    b, c, t, h, w = x.size()
         | 
| 103 | 
            +
                    if self.mode == 'upsample3d':
         | 
| 104 | 
            +
                        if feat_cache is not None:
         | 
| 105 | 
            +
                            idx = feat_idx[0]
         | 
| 106 | 
            +
                            if feat_cache[idx] is None:
         | 
| 107 | 
            +
                                feat_cache[idx] = 'Rep'
         | 
| 108 | 
            +
                                feat_idx[0] += 1
         | 
| 109 | 
            +
                            else:
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                                cache_x = x[:, :, -CACHE_T:, :, :].clone()
         | 
| 112 | 
            +
                                if cache_x.shape[2] < 2 and feat_cache[
         | 
| 113 | 
            +
                                        idx] is not None and feat_cache[idx] != 'Rep':
         | 
| 114 | 
            +
                                    # cache last frame of last two chunk
         | 
| 115 | 
            +
                                    cache_x = torch.cat([
         | 
| 116 | 
            +
                                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
         | 
| 117 | 
            +
                                            cache_x.device), cache_x
         | 
| 118 | 
            +
                                    ],
         | 
| 119 | 
            +
                                                        dim=2)
         | 
| 120 | 
            +
                                if cache_x.shape[2] < 2 and feat_cache[
         | 
| 121 | 
            +
                                        idx] is not None and feat_cache[idx] == 'Rep':
         | 
| 122 | 
            +
                                    cache_x = torch.cat([
         | 
| 123 | 
            +
                                        torch.zeros_like(cache_x).to(cache_x.device),
         | 
| 124 | 
            +
                                        cache_x
         | 
| 125 | 
            +
                                    ],
         | 
| 126 | 
            +
                                                        dim=2)
         | 
| 127 | 
            +
                                if feat_cache[idx] == 'Rep':
         | 
| 128 | 
            +
                                    x = self.time_conv(x)
         | 
| 129 | 
            +
                                else:
         | 
| 130 | 
            +
                                    x = self.time_conv(x, feat_cache[idx])
         | 
| 131 | 
            +
                                feat_cache[idx] = cache_x
         | 
| 132 | 
            +
                                feat_idx[0] += 1
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                                x = x.reshape(b, 2, c, t, h, w)
         | 
| 135 | 
            +
                                x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
         | 
| 136 | 
            +
                                                3)
         | 
| 137 | 
            +
                                x = x.reshape(b, c, t * 2, h, w)
         | 
| 138 | 
            +
                    t = x.shape[2]
         | 
| 139 | 
            +
                    x = rearrange(x, 'b c t h w -> (b t) c h w')
         | 
| 140 | 
            +
                    x = self.resample(x)
         | 
| 141 | 
            +
                    x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                    if self.mode == 'downsample3d':
         | 
| 144 | 
            +
                        if feat_cache is not None:
         | 
| 145 | 
            +
                            idx = feat_idx[0]
         | 
| 146 | 
            +
                            if feat_cache[idx] is None:
         | 
| 147 | 
            +
                                feat_cache[idx] = x.clone()
         | 
| 148 | 
            +
                                feat_idx[0] += 1
         | 
| 149 | 
            +
                            else:
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                                cache_x = x[:, :, -1:, :, :].clone()
         | 
| 152 | 
            +
                                # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
         | 
| 153 | 
            +
                                #     # cache last frame of last two chunk
         | 
| 154 | 
            +
                                #     cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                                x = self.time_conv(
         | 
| 157 | 
            +
                                    torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
         | 
| 158 | 
            +
                                feat_cache[idx] = cache_x
         | 
| 159 | 
            +
                                feat_idx[0] += 1
         | 
| 160 | 
            +
                    return x
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                def init_weight(self, conv):
         | 
| 163 | 
            +
                    conv_weight = conv.weight
         | 
| 164 | 
            +
                    nn.init.zeros_(conv_weight)
         | 
| 165 | 
            +
                    c1, c2, t, h, w = conv_weight.size()
         | 
| 166 | 
            +
                    one_matrix = torch.eye(c1, c2)
         | 
| 167 | 
            +
                    init_matrix = one_matrix
         | 
| 168 | 
            +
                    nn.init.zeros_(conv_weight)
         | 
| 169 | 
            +
                    #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
         | 
| 170 | 
            +
                    conv_weight.data[:, :, 1, 0, 0] = init_matrix  #* 0.5
         | 
| 171 | 
            +
                    conv.weight.data.copy_(conv_weight)
         | 
| 172 | 
            +
                    nn.init.zeros_(conv.bias.data)
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                def init_weight2(self, conv):
         | 
| 175 | 
            +
                    conv_weight = conv.weight.data
         | 
| 176 | 
            +
                    nn.init.zeros_(conv_weight)
         | 
| 177 | 
            +
                    c1, c2, t, h, w = conv_weight.size()
         | 
| 178 | 
            +
                    init_matrix = torch.eye(c1 // 2, c2)
         | 
| 179 | 
            +
                    #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
         | 
| 180 | 
            +
                    conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
         | 
| 181 | 
            +
                    conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
         | 
| 182 | 
            +
                    conv.weight.data.copy_(conv_weight)
         | 
| 183 | 
            +
                    nn.init.zeros_(conv.bias.data)
         | 
| 184 | 
            +
             | 
| 185 | 
            +
             | 
| 186 | 
            +
            class ResidualBlock(nn.Module):
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                def __init__(self, in_dim, out_dim, dropout=0.0):
         | 
| 189 | 
            +
                    super().__init__()
         | 
| 190 | 
            +
                    self.in_dim = in_dim
         | 
| 191 | 
            +
                    self.out_dim = out_dim
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                    # layers
         | 
| 194 | 
            +
                    self.residual = nn.Sequential(
         | 
| 195 | 
            +
                        RMS_norm(in_dim, images=False), nn.SiLU(),
         | 
| 196 | 
            +
                        CausalConv3d(in_dim, out_dim, 3, padding=1),
         | 
| 197 | 
            +
                        RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
         | 
| 198 | 
            +
                        CausalConv3d(out_dim, out_dim, 3, padding=1))
         | 
| 199 | 
            +
                    self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
         | 
| 200 | 
            +
                        if in_dim != out_dim else nn.Identity()
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                def forward(self, x, feat_cache=None, feat_idx=[0]):
         | 
| 203 | 
            +
                    h = self.shortcut(x)
         | 
| 204 | 
            +
                    for layer in self.residual:
         | 
| 205 | 
            +
                        if isinstance(layer, CausalConv3d) and feat_cache is not None:
         | 
| 206 | 
            +
                            idx = feat_idx[0]
         | 
| 207 | 
            +
                            cache_x = x[:, :, -CACHE_T:, :, :].clone()
         | 
| 208 | 
            +
                            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
         | 
| 209 | 
            +
                                # cache last frame of last two chunk
         | 
| 210 | 
            +
                                cache_x = torch.cat([
         | 
| 211 | 
            +
                                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
         | 
| 212 | 
            +
                                        cache_x.device), cache_x
         | 
| 213 | 
            +
                                ],
         | 
| 214 | 
            +
                                                    dim=2)
         | 
| 215 | 
            +
                            x = layer(x, feat_cache[idx])
         | 
| 216 | 
            +
                            feat_cache[idx] = cache_x
         | 
| 217 | 
            +
                            feat_idx[0] += 1
         | 
| 218 | 
            +
                        else:
         | 
| 219 | 
            +
                            x = layer(x)
         | 
| 220 | 
            +
                    return x + h
         | 
| 221 | 
            +
             | 
| 222 | 
            +
             | 
| 223 | 
            +
            class AttentionBlock(nn.Module):
         | 
| 224 | 
            +
                """
         | 
| 225 | 
            +
                Causal self-attention with a single head.
         | 
| 226 | 
            +
                """
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                def __init__(self, dim):
         | 
| 229 | 
            +
                    super().__init__()
         | 
| 230 | 
            +
                    self.dim = dim
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                    # layers
         | 
| 233 | 
            +
                    self.norm = RMS_norm(dim)
         | 
| 234 | 
            +
                    self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
         | 
| 235 | 
            +
                    self.proj = nn.Conv2d(dim, dim, 1)
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                    # zero out the last layer params
         | 
| 238 | 
            +
                    nn.init.zeros_(self.proj.weight)
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                def forward(self, x):
         | 
| 241 | 
            +
                    identity = x
         | 
| 242 | 
            +
                    b, c, t, h, w = x.size()
         | 
| 243 | 
            +
                    x = rearrange(x, 'b c t h w -> (b t) c h w')
         | 
| 244 | 
            +
                    x = self.norm(x)
         | 
| 245 | 
            +
                    # compute query, key, value
         | 
| 246 | 
            +
                    q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
         | 
| 247 | 
            +
                                                     -1).permute(0, 1, 3,
         | 
| 248 | 
            +
                                                                 2).contiguous().chunk(
         | 
| 249 | 
            +
                                                                     3, dim=-1)
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                    # apply attention
         | 
| 252 | 
            +
                    x = F.scaled_dot_product_attention(
         | 
| 253 | 
            +
                        q,
         | 
| 254 | 
            +
                        k,
         | 
| 255 | 
            +
                        v,
         | 
| 256 | 
            +
                    )
         | 
| 257 | 
            +
                    x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                    # output
         | 
| 260 | 
            +
                    x = self.proj(x)
         | 
| 261 | 
            +
                    x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
         | 
| 262 | 
            +
                    return x + identity
         | 
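# Note: the attention above is purely spatial in practice: frames are folded
# into the batch ('(b t) c h w'), so each frame attends over its own h*w
# positions with a single head, and no information flows across time.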
| 263 | 
            +
             | 
| 264 | 
            +
             | 
| 265 | 
            +
            class Encoder3d(nn.Module):
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                def __init__(self,
         | 
| 268 | 
            +
                             dim=128,
         | 
| 269 | 
            +
                             z_dim=4,
         | 
| 270 | 
            +
                             dim_mult=[1, 2, 4, 4],
         | 
| 271 | 
            +
                             num_res_blocks=2,
         | 
| 272 | 
            +
                             attn_scales=[],
         | 
| 273 | 
            +
                             temperal_downsample=[True, True, False],
         | 
| 274 | 
            +
                             dropout=0.0):
         | 
| 275 | 
            +
                    super().__init__()
         | 
| 276 | 
            +
                    self.dim = dim
         | 
| 277 | 
            +
                    self.z_dim = z_dim
         | 
| 278 | 
            +
                    self.dim_mult = dim_mult
         | 
| 279 | 
            +
                    self.num_res_blocks = num_res_blocks
         | 
| 280 | 
            +
                    self.attn_scales = attn_scales
         | 
| 281 | 
            +
                    self.temperal_downsample = temperal_downsample
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                    # dimensions
         | 
| 284 | 
            +
                    dims = [dim * u for u in [1] + dim_mult]
         | 
| 285 | 
            +
                    scale = 1.0
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                    # init block
         | 
| 288 | 
            +
                    self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                    # downsample blocks
         | 
| 291 | 
            +
                    downsamples = []
         | 
| 292 | 
            +
                    for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
         | 
| 293 | 
            +
                        # residual (+attention) blocks
         | 
| 294 | 
            +
                        for _ in range(num_res_blocks):
         | 
| 295 | 
            +
                            downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
         | 
| 296 | 
            +
                            if scale in attn_scales:
         | 
| 297 | 
            +
                                downsamples.append(AttentionBlock(out_dim))
         | 
| 298 | 
            +
                            in_dim = out_dim
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                        # downsample block
         | 
| 301 | 
            +
                        if i != len(dim_mult) - 1:
         | 
| 302 | 
            +
                            mode = 'downsample3d' if temperal_downsample[
         | 
| 303 | 
            +
                                i] else 'downsample2d'
         | 
| 304 | 
            +
                            downsamples.append(Resample(out_dim, mode=mode))
         | 
| 305 | 
            +
                            scale /= 2.0
         | 
| 306 | 
            +
                    self.downsamples = nn.Sequential(*downsamples)
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                    # middle blocks
         | 
| 309 | 
            +
                    self.middle = nn.Sequential(
         | 
| 310 | 
            +
                        ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
         | 
| 311 | 
            +
                        ResidualBlock(out_dim, out_dim, dropout))
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                    # output blocks
         | 
| 314 | 
            +
                    self.head = nn.Sequential(
         | 
| 315 | 
            +
                        RMS_norm(out_dim, images=False), nn.SiLU(),
         | 
| 316 | 
            +
                        CausalConv3d(out_dim, z_dim, 3, padding=1))
         | 
| 317 | 
            +
             | 
| 318 | 
            +
                def forward(self, x, feat_cache=None, feat_idx=[0]):
         | 
| 319 | 
            +
                    if feat_cache is not None:
         | 
| 320 | 
            +
                        idx = feat_idx[0]
         | 
| 321 | 
            +
                        cache_x = x[:, :, -CACHE_T:, :, :].clone()
         | 
| 322 | 
            +
                        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
         | 
| 323 | 
            +
                            # cache last frame of last two chunk
         | 
| 324 | 
            +
                            cache_x = torch.cat([
         | 
| 325 | 
            +
                                feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
         | 
| 326 | 
            +
                                    cache_x.device), cache_x
         | 
| 327 | 
            +
                            ],
         | 
| 328 | 
            +
                                                dim=2)
         | 
| 329 | 
            +
                        x = self.conv1(x, feat_cache[idx])
         | 
| 330 | 
            +
                        feat_cache[idx] = cache_x
         | 
| 331 | 
            +
                        feat_idx[0] += 1
         | 
| 332 | 
            +
                    else:
         | 
| 333 | 
            +
                        x = self.conv1(x)
         | 
| 334 | 
            +
             | 
| 335 | 
            +
                    ## downsamples
         | 
| 336 | 
            +
                    for layer in self.downsamples:
         | 
| 337 | 
            +
                        if feat_cache is not None:
         | 
| 338 | 
            +
                            x = layer(x, feat_cache, feat_idx)
         | 
| 339 | 
            +
                        else:
         | 
| 340 | 
            +
                            x = layer(x)
         | 
| 341 | 
            +
             | 
| 342 | 
            +
                    ## middle
         | 
| 343 | 
            +
                    for layer in self.middle:
         | 
| 344 | 
            +
                        if isinstance(layer, ResidualBlock) and feat_cache is not None:
         | 
| 345 | 
            +
                            x = layer(x, feat_cache, feat_idx)
         | 
| 346 | 
            +
                        else:
         | 
| 347 | 
            +
                            x = layer(x)
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                    ## head
         | 
| 350 | 
            +
                    for layer in self.head:
         | 
| 351 | 
            +
                        if isinstance(layer, CausalConv3d) and feat_cache is not None:
         | 
| 352 | 
            +
                            idx = feat_idx[0]
         | 
| 353 | 
            +
                            cache_x = x[:, :, -CACHE_T:, :, :].clone()
         | 
| 354 | 
            +
                            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
         | 
| 355 | 
            +
                                # also cache the last frame of the previous chunk
         | 
| 356 | 
            +
                                cache_x = torch.cat([
         | 
| 357 | 
            +
                                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
         | 
| 358 | 
            +
                                        cache_x.device), cache_x
         | 
| 359 | 
            +
                                ],
         | 
| 360 | 
            +
                                                    dim=2)
         | 
| 361 | 
            +
                            x = layer(x, feat_cache[idx])
         | 
| 362 | 
            +
                            feat_cache[idx] = cache_x
         | 
| 363 | 
            +
                            feat_idx[0] += 1
         | 
| 364 | 
            +
                        else:
         | 
| 365 | 
            +
                            x = layer(x)
         | 
| 366 | 
            +
                    return x
         | 
| 367 | 
            +
             | 
| 368 | 
            +
             | 
| 369 | 
            +
            class Decoder3d(nn.Module):
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                def __init__(self,
         | 
| 372 | 
            +
                             dim=128,
         | 
| 373 | 
            +
                             z_dim=4,
         | 
| 374 | 
            +
                             dim_mult=[1, 2, 4, 4],
         | 
| 375 | 
            +
                             num_res_blocks=2,
         | 
| 376 | 
            +
                             attn_scales=[],
         | 
| 377 | 
            +
                             temperal_upsample=[False, True, True],
         | 
| 378 | 
            +
                             dropout=0.0):
         | 
| 379 | 
            +
                    super().__init__()
         | 
| 380 | 
            +
                    self.dim = dim
         | 
| 381 | 
            +
                    self.z_dim = z_dim
         | 
| 382 | 
            +
                    self.dim_mult = dim_mult
         | 
| 383 | 
            +
                    self.num_res_blocks = num_res_blocks
         | 
| 384 | 
            +
                    self.attn_scales = attn_scales
         | 
| 385 | 
            +
                    self.temperal_upsample = temperal_upsample
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                    # dimensions
         | 
| 388 | 
            +
                    dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
         | 
| 389 | 
            +
                    scale = 1.0 / 2**(len(dim_mult) - 2)
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                    # init block
         | 
| 392 | 
            +
                    self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
         | 
| 393 | 
            +
             | 
| 394 | 
            +
                    # middle blocks
         | 
| 395 | 
            +
                    self.middle = nn.Sequential(
         | 
| 396 | 
            +
                        ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
         | 
| 397 | 
            +
                        ResidualBlock(dims[0], dims[0], dropout))
         | 
| 398 | 
            +
             | 
| 399 | 
            +
                    # upsample blocks
         | 
| 400 | 
            +
                    upsamples = []
         | 
| 401 | 
            +
                    for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
         | 
| 402 | 
            +
                        # residual (+attention) blocks
         | 
| 403 | 
            +
                        if i == 1 or i == 2 or i == 3:
         | 
| 404 | 
            +
                            in_dim = in_dim // 2
         | 
| 405 | 
            +
                        for _ in range(num_res_blocks + 1):
         | 
| 406 | 
            +
                            upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
         | 
| 407 | 
            +
                            if scale in attn_scales:
         | 
| 408 | 
            +
                                upsamples.append(AttentionBlock(out_dim))
         | 
| 409 | 
            +
                            in_dim = out_dim
         | 
| 410 | 
            +
             | 
| 411 | 
            +
                        # upsample block
         | 
| 412 | 
            +
                        if i != len(dim_mult) - 1:
         | 
| 413 | 
            +
                            mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
         | 
| 414 | 
            +
                            upsamples.append(Resample(out_dim, mode=mode))
         | 
| 415 | 
            +
                            scale *= 2.0
         | 
| 416 | 
            +
                    self.upsamples = nn.Sequential(*upsamples)
         | 
| 417 | 
            +
             | 
| 418 | 
            +
                    # output blocks
         | 
| 419 | 
            +
                    self.head = nn.Sequential(
         | 
| 420 | 
            +
                        RMS_norm(out_dim, images=False), nn.SiLU(),
         | 
| 421 | 
            +
                        CausalConv3d(out_dim, 3, 3, padding=1))
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                def forward(self, x, feat_cache=None, feat_idx=[0]):
         | 
| 424 | 
            +
                    ## conv1
         | 
| 425 | 
            +
                    if feat_cache is not None:
         | 
| 426 | 
            +
                        idx = feat_idx[0]
         | 
| 427 | 
            +
                        cache_x = x[:, :, -CACHE_T:, :, :].clone()
         | 
| 428 | 
            +
                        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
         | 
| 429 | 
            +
                            # also cache the last frame of the previous chunk
         | 
| 430 | 
            +
                            cache_x = torch.cat([
         | 
| 431 | 
            +
                                feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
         | 
| 432 | 
            +
                                    cache_x.device), cache_x
         | 
| 433 | 
            +
                            ],
         | 
| 434 | 
            +
                                                dim=2)
         | 
| 435 | 
            +
                        x = self.conv1(x, feat_cache[idx])
         | 
| 436 | 
            +
                        feat_cache[idx] = cache_x
         | 
| 437 | 
            +
                        feat_idx[0] += 1
         | 
| 438 | 
            +
                    else:
         | 
| 439 | 
            +
                        x = self.conv1(x)
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                    ## middle
         | 
| 442 | 
            +
                    for layer in self.middle:
         | 
| 443 | 
            +
                        if isinstance(layer, ResidualBlock) and feat_cache is not None:
         | 
| 444 | 
            +
                            x = layer(x, feat_cache, feat_idx)
         | 
| 445 | 
            +
                        else:
         | 
| 446 | 
            +
                            x = layer(x)
         | 
| 447 | 
            +
             | 
| 448 | 
            +
                    ## upsamples
         | 
| 449 | 
            +
                    for layer in self.upsamples:
         | 
| 450 | 
            +
                        if feat_cache is not None:
         | 
| 451 | 
            +
                            x = layer(x, feat_cache, feat_idx)
         | 
| 452 | 
            +
                        else:
         | 
| 453 | 
            +
                            x = layer(x)
         | 
| 454 | 
            +
             | 
| 455 | 
            +
                    ## head
         | 
| 456 | 
            +
                    for layer in self.head:
         | 
| 457 | 
            +
                        if isinstance(layer, CausalConv3d) and feat_cache is not None:
         | 
| 458 | 
            +
                            idx = feat_idx[0]
         | 
| 459 | 
            +
                            cache_x = x[:, :, -CACHE_T:, :, :].clone()
         | 
| 460 | 
            +
                            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
         | 
| 461 | 
            +
                                # also cache the last frame of the previous chunk
         | 
| 462 | 
            +
                                cache_x = torch.cat([
         | 
| 463 | 
            +
                                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
         | 
| 464 | 
            +
                                        cache_x.device), cache_x
         | 
| 465 | 
            +
                                ],
         | 
| 466 | 
            +
                                                    dim=2)
         | 
| 467 | 
            +
                            x = layer(x, feat_cache[idx])
         | 
| 468 | 
            +
                            feat_cache[idx] = cache_x
         | 
| 469 | 
            +
                            feat_idx[0] += 1
         | 
| 470 | 
            +
                        else:
         | 
| 471 | 
            +
                            x = layer(x)
         | 
| 472 | 
            +
                    return x
         | 
| 473 | 
            +
             | 
| 474 | 
            +
             | 
| 475 | 
            +
            def count_conv3d(model):
         | 
| 476 | 
            +
                count = 0
         | 
| 477 | 
            +
                for m in model.modules():
         | 
| 478 | 
            +
                    if isinstance(m, CausalConv3d):
         | 
| 479 | 
            +
                        count += 1
         | 
| 480 | 
            +
                return count
         | 
| 481 | 
            +
             | 
| 482 | 
            +
             | 
| 483 | 
            +
            class WanVAE_(nn.Module):
         | 
| 484 | 
            +
             | 
| 485 | 
            +
                def __init__(self,
         | 
| 486 | 
            +
                             dim=128,
         | 
| 487 | 
            +
                             z_dim=4,
         | 
| 488 | 
            +
                             dim_mult=[1, 2, 4, 4],
         | 
| 489 | 
            +
                             num_res_blocks=2,
         | 
| 490 | 
            +
                             attn_scales=[],
         | 
| 491 | 
            +
                             temperal_downsample=[True, True, False],
         | 
| 492 | 
            +
                             dropout=0.0):
         | 
| 493 | 
            +
                    super().__init__()
         | 
| 494 | 
            +
                    self.dim = dim
         | 
| 495 | 
            +
                    self.z_dim = z_dim
         | 
| 496 | 
            +
                    self.dim_mult = dim_mult
         | 
| 497 | 
            +
                    self.num_res_blocks = num_res_blocks
         | 
| 498 | 
            +
                    self.attn_scales = attn_scales
         | 
| 499 | 
            +
                    self.temperal_downsample = temperal_downsample
         | 
| 500 | 
            +
                    self.temperal_upsample = temperal_downsample[::-1]
         | 
| 501 | 
            +
             | 
| 502 | 
            +
                    # modules
         | 
| 503 | 
            +
                    self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
         | 
| 504 | 
            +
                                             attn_scales, self.temperal_downsample, dropout)
         | 
| 505 | 
            +
                    self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
         | 
| 506 | 
            +
                    self.conv2 = CausalConv3d(z_dim, z_dim, 1)
         | 
| 507 | 
            +
                    self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
         | 
| 508 | 
            +
                                             attn_scales, self.temperal_upsample, dropout)
         | 
| 509 | 
            +
             | 
| 510 | 
            +
                def forward(self, x):
         | 
| 511 | 
            +
                    mu, log_var = self.encode(x)
         | 
| 512 | 
            +
                    z = self.reparameterize(mu, log_var)
         | 
| 513 | 
            +
                    x_recon = self.decode(z)
         | 
| 514 | 
            +
                    return x_recon, mu, log_var
         | 
| 515 | 
            +
             | 
| 516 | 
            +
                def encode(self, x, scale):
         | 
| 517 | 
            +
                    self.clear_cache()
         | 
| 518 | 
            +
                    ## cache
         | 
| 519 | 
            +
                    t = x.shape[2]
         | 
| 520 | 
            +
                    iter_ = 1 + (t - 1) // 4
         | 
| 521 | 
            +
                ## split the encoder input x into temporal chunks of 1, 4, 4, 4, ... frames
         | 
| 522 | 
            +
                    for i in range(iter_):
         | 
| 523 | 
            +
                        self._enc_conv_idx = [0]
         | 
| 524 | 
            +
                        if i == 0:
         | 
| 525 | 
            +
                            out = self.encoder(
         | 
| 526 | 
            +
                                x[:, :, :1, :, :],
         | 
| 527 | 
            +
                                feat_cache=self._enc_feat_map,
         | 
| 528 | 
            +
                                feat_idx=self._enc_conv_idx)
         | 
| 529 | 
            +
                        else:
         | 
| 530 | 
            +
                            out_ = self.encoder(
         | 
| 531 | 
            +
                                x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
         | 
| 532 | 
            +
                                feat_cache=self._enc_feat_map,
         | 
| 533 | 
            +
                                feat_idx=self._enc_conv_idx)
         | 
| 534 | 
            +
                            out = torch.cat([out, out_], 2)
         | 
| 535 | 
            +
                    mu, log_var = self.conv1(out).chunk(2, dim=1)
         | 
| 536 | 
            +
                    if isinstance(scale[0], torch.Tensor):
         | 
| 537 | 
            +
                        mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
         | 
| 538 | 
            +
                            1, self.z_dim, 1, 1, 1)
         | 
| 539 | 
            +
                    else:
         | 
| 540 | 
            +
                        mu = (mu - scale[0]) * scale[1]
         | 
| 541 | 
            +
                    self.clear_cache()
         | 
| 542 | 
            +
                    return mu
         | 
| 543 | 
            +
             | 
| 544 | 
            +
                def decode(self, z, scale):
         | 
| 545 | 
            +
                    self.clear_cache()
         | 
| 546 | 
            +
                    # z: [b,c,t,h,w]
         | 
| 547 | 
            +
                    if isinstance(scale[0], torch.Tensor):
         | 
| 548 | 
            +
                        z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
         | 
| 549 | 
            +
                            1, self.z_dim, 1, 1, 1)
         | 
| 550 | 
            +
                    else:
         | 
| 551 | 
            +
                        z = z / scale[1] + scale[0]
         | 
| 552 | 
            +
                    iter_ = z.shape[2]
         | 
| 553 | 
            +
                    x = self.conv2(z)
         | 
| 554 | 
            +
                    for i in range(iter_):
         | 
| 555 | 
            +
                        self._conv_idx = [0]
         | 
| 556 | 
            +
                        if i == 0:
         | 
| 557 | 
            +
                            out = self.decoder(
         | 
| 558 | 
            +
                                x[:, :, i:i + 1, :, :],
         | 
| 559 | 
            +
                                feat_cache=self._feat_map,
         | 
| 560 | 
            +
                                feat_idx=self._conv_idx)
         | 
| 561 | 
            +
                        else:
         | 
| 562 | 
            +
                            out_ = self.decoder(
         | 
| 563 | 
            +
                                x[:, :, i:i + 1, :, :],
         | 
| 564 | 
            +
                                feat_cache=self._feat_map,
         | 
| 565 | 
            +
                                feat_idx=self._conv_idx)
         | 
| 566 | 
            +
                            out = torch.cat([out, out_], 2)
         | 
| 567 | 
            +
                    self.clear_cache()
         | 
| 568 | 
            +
                    return out
         | 
| 569 | 
            +
             | 
| 570 | 
            +
                def reparameterize(self, mu, log_var):
         | 
| 571 | 
            +
                    std = torch.exp(0.5 * log_var)
         | 
| 572 | 
            +
                    eps = torch.randn_like(std)
         | 
| 573 | 
            +
                    return eps * std + mu
         | 
| 574 | 
            +
             | 
| 575 | 
            +
                def sample(self, imgs, deterministic=False):
         | 
| 576 | 
            +
                    mu, log_var = self.encode(imgs)
         | 
| 577 | 
            +
                    if deterministic:
         | 
| 578 | 
            +
                        return mu
         | 
| 579 | 
            +
                    std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
         | 
| 580 | 
            +
                    return mu + std * torch.randn_like(std)
         | 
| 581 | 
            +
             | 
| 582 | 
            +
                def clear_cache(self):
         | 
| 583 | 
            +
                    self._conv_num = count_conv3d(self.decoder)
         | 
| 584 | 
            +
                    self._conv_idx = [0]
         | 
| 585 | 
            +
                    self._feat_map = [None] * self._conv_num
         | 
| 586 | 
            +
                    #cache encode
         | 
| 587 | 
            +
                    self._enc_conv_num = count_conv3d(self.encoder)
         | 
| 588 | 
            +
                    self._enc_conv_idx = [0]
         | 
| 589 | 
            +
                    self._enc_feat_map = [None] * self._enc_conv_num
         | 
| 590 | 
            +
             | 
| 591 | 
            +
             | 
| 592 | 
            +
            def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
         | 
| 593 | 
            +
                """
         | 
| 594 | 
            +
                Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
         | 
| 595 | 
            +
                """
         | 
| 596 | 
            +
                # params
         | 
| 597 | 
            +
                cfg = dict(
         | 
| 598 | 
            +
                    dim=96,
         | 
| 599 | 
            +
                    z_dim=z_dim,
         | 
| 600 | 
            +
                    dim_mult=[1, 2, 4, 4],
         | 
| 601 | 
            +
                    num_res_blocks=2,
         | 
| 602 | 
            +
                    attn_scales=[],
         | 
| 603 | 
            +
                    temperal_downsample=[False, True, True],
         | 
| 604 | 
            +
                    dropout=0.0)
         | 
| 605 | 
            +
                cfg.update(**kwargs)
         | 
| 606 | 
            +
             | 
| 607 | 
            +
                # init model
         | 
| 608 | 
            +
                # with torch.device('meta'):
         | 
| 609 | 
            +
                model = WanVAE_(**cfg)
         | 
| 610 | 
            +
             | 
| 611 | 
            +
                # load checkpoint
         | 
| 612 | 
            +
                logging.info(f'loading {pretrained_path}')
         | 
| 613 | 
            +
                if pretrained_path is not None:
         | 
| 614 | 
            +
                    model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True)
         | 
| 615 | 
            +
             | 
| 616 | 
            +
                return model
         | 
| 617 | 
            +
             | 
| 618 | 
            +
             | 
| 619 | 
            +
            class WanVAE:
         | 
| 620 | 
            +
             | 
| 621 | 
            +
                def __init__(self,
         | 
| 622 | 
            +
                             z_dim=16,
         | 
| 623 | 
            +
                             vae_pth=None,
         | 
| 624 | 
            +
                             dtype=torch.float,
         | 
| 625 | 
            +
                             device="cuda"):
         | 
| 626 | 
            +
                    self.dtype = dtype
         | 
| 627 | 
            +
                    self.device = device
         | 
| 628 | 
            +
             | 
| 629 | 
            +
                    mean = [
         | 
| 630 | 
            +
                        -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
         | 
| 631 | 
            +
                        0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
         | 
| 632 | 
            +
                    ]
         | 
| 633 | 
            +
                    std = [
         | 
| 634 | 
            +
                        2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
         | 
| 635 | 
            +
                        3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
         | 
| 636 | 
            +
                    ]
         | 
| 637 | 
            +
                    self.mean = torch.tensor(mean, dtype=dtype, device=device)
         | 
| 638 | 
            +
                    self.std = torch.tensor(std, dtype=dtype, device=device)
         | 
| 639 | 
            +
                    self.scale = [self.mean, 1.0 / self.std]
         | 
| 640 | 
            +
             | 
| 641 | 
            +
                    # init model
         | 
| 642 | 
            +
                    self.model = _video_vae(
         | 
| 643 | 
            +
                        pretrained_path=vae_pth,
         | 
| 644 | 
            +
                        z_dim=z_dim,
         | 
| 645 | 
            +
                    ).eval().requires_grad_(False).to(device)
         | 
| 646 | 
            +
             | 
| 647 | 
            +
                @torch.no_grad()
         | 
| 648 | 
            +
                def encode(self, videos, device):
         | 
| 649 | 
            +
                    """
         | 
| 650 | 
            +
                    videos: A list of videos each with shape [C, T, H, W].
         | 
| 651 | 
            +
                    """
         | 
| 652 | 
            +
             | 
| 653 | 
            +
                    with torch.amp.autocast('cuda', dtype=self.dtype):
         | 
| 654 | 
            +
                        return [
         | 
| 655 | 
            +
                            self.model.encode(u.unsqueeze(0).to(device,self.dtype), self.scale).float().squeeze(0)
         | 
| 656 | 
            +
                            for u in videos
         | 
| 657 | 
            +
                        ]
         | 
| 658 | 
            +
             | 
| 659 | 
            +
                @torch.no_grad()
         | 
| 660 | 
            +
                def decode(self, zs):
         | 
| 661 | 
            +
                    with torch.amp.autocast('cuda', dtype=self.dtype):
         | 
| 662 | 
            +
                        return [
         | 
| 663 | 
            +
                            self.model.decode(u.unsqueeze(0),
         | 
| 664 | 
            +
                                              self.scale).float().clamp_(-1, 1).squeeze(0)
         | 
| 665 | 
            +
                            for u in zs
         | 
| 666 | 
            +
                        ]
         | 
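
A quick sanity-check of the WanVAE wrapper defined above, as a minimal sketch: the checkpoint path, device, and tensor shapes below are illustrative assumptions, not values from this diff. Frame counts of the form 1 + 4k line up with the 1, 4, 4, ... temporal chunking that encode() and decode() apply internally.

    # Hypothetical usage of WanVAE; the .pth path and shapes are assumptions.
    import torch

    vae = WanVAE(z_dim=16, vae_pth="path/to/vae_checkpoint.pth",
                 dtype=torch.float16, device="cuda")

    video = torch.randn(3, 81, 480, 832)          # [C, T, H, W], T = 1 + 4 * 20

    latents = vae.encode([video], device="cuda")  # list of [16, 21, 60, 104] tensors
    recon = vae.decode(latents)                   # list of [3, 81, 480, 832], clamped to [-1, 1]

The per-channel mean/std pair stored in self.scale is what makes the latents roughly unit-variance before they reach the diffusion model; decode() inverts the same affine map.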
    	
        humo/models/wan_modules/xlm_roberta.py
    ADDED
    
    | @@ -0,0 +1,170 @@ | |
| 1 | 
            +
            # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
         | 
| 2 | 
            +
            # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.nn as nn
         | 
| 5 | 
            +
            import torch.nn.functional as F
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            __all__ = ['XLMRoberta', 'xlm_roberta_large']
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class SelfAttention(nn.Module):
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
         | 
| 13 | 
            +
                    assert dim % num_heads == 0
         | 
| 14 | 
            +
                    super().__init__()
         | 
| 15 | 
            +
                    self.dim = dim
         | 
| 16 | 
            +
                    self.num_heads = num_heads
         | 
| 17 | 
            +
                    self.head_dim = dim // num_heads
         | 
| 18 | 
            +
                    self.eps = eps
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                    # layers
         | 
| 21 | 
            +
                    self.q = nn.Linear(dim, dim)
         | 
| 22 | 
            +
                    self.k = nn.Linear(dim, dim)
         | 
| 23 | 
            +
                    self.v = nn.Linear(dim, dim)
         | 
| 24 | 
            +
                    self.o = nn.Linear(dim, dim)
         | 
| 25 | 
            +
                    self.dropout = nn.Dropout(dropout)
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                def forward(self, x, mask):
         | 
| 28 | 
            +
                    """
         | 
| 29 | 
            +
                    x:   [B, L, C].
         | 
| 30 | 
            +
                    """
         | 
| 31 | 
            +
                    b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                    # compute query, key, value
         | 
| 34 | 
            +
                    q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
         | 
| 35 | 
            +
                    k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
         | 
| 36 | 
            +
                    v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                    # compute attention
         | 
| 39 | 
            +
                    p = self.dropout.p if self.training else 0.0
         | 
| 40 | 
            +
                    x = F.scaled_dot_product_attention(q, k, v, mask, p)
         | 
| 41 | 
            +
                    x = x.permute(0, 2, 1, 3).reshape(b, s, c)
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                    # output
         | 
| 44 | 
            +
                    x = self.o(x)
         | 
| 45 | 
            +
                    x = self.dropout(x)
         | 
| 46 | 
            +
                    return x
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            class AttentionBlock(nn.Module):
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
         | 
| 52 | 
            +
                    super().__init__()
         | 
| 53 | 
            +
                    self.dim = dim
         | 
| 54 | 
            +
                    self.num_heads = num_heads
         | 
| 55 | 
            +
                    self.post_norm = post_norm
         | 
| 56 | 
            +
                    self.eps = eps
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    # layers
         | 
| 59 | 
            +
                    self.attn = SelfAttention(dim, num_heads, dropout, eps)
         | 
| 60 | 
            +
                    self.norm1 = nn.LayerNorm(dim, eps=eps)
         | 
| 61 | 
            +
                    self.ffn = nn.Sequential(
         | 
| 62 | 
            +
                        nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
         | 
| 63 | 
            +
                        nn.Dropout(dropout))
         | 
| 64 | 
            +
                    self.norm2 = nn.LayerNorm(dim, eps=eps)
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                def forward(self, x, mask):
         | 
| 67 | 
            +
                    if self.post_norm:
         | 
| 68 | 
            +
                        x = self.norm1(x + self.attn(x, mask))
         | 
| 69 | 
            +
                        x = self.norm2(x + self.ffn(x))
         | 
| 70 | 
            +
                    else:
         | 
| 71 | 
            +
                        x = x + self.attn(self.norm1(x), mask)
         | 
| 72 | 
            +
                        x = x + self.ffn(self.norm2(x))
         | 
| 73 | 
            +
                    return x
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            class XLMRoberta(nn.Module):
         | 
| 77 | 
            +
                """
         | 
| 78 | 
            +
                XLMRobertaModel with no pooler and no LM head.
         | 
| 79 | 
            +
                """
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                def __init__(self,
         | 
| 82 | 
            +
                             vocab_size=250002,
         | 
| 83 | 
            +
                             max_seq_len=514,
         | 
| 84 | 
            +
                             type_size=1,
         | 
| 85 | 
            +
                             pad_id=1,
         | 
| 86 | 
            +
                             dim=1024,
         | 
| 87 | 
            +
                             num_heads=16,
         | 
| 88 | 
            +
                             num_layers=24,
         | 
| 89 | 
            +
                             post_norm=True,
         | 
| 90 | 
            +
                             dropout=0.1,
         | 
| 91 | 
            +
                             eps=1e-5):
         | 
| 92 | 
            +
                    super().__init__()
         | 
| 93 | 
            +
                    self.vocab_size = vocab_size
         | 
| 94 | 
            +
                    self.max_seq_len = max_seq_len
         | 
| 95 | 
            +
                    self.type_size = type_size
         | 
| 96 | 
            +
                    self.pad_id = pad_id
         | 
| 97 | 
            +
                    self.dim = dim
         | 
| 98 | 
            +
                    self.num_heads = num_heads
         | 
| 99 | 
            +
                    self.num_layers = num_layers
         | 
| 100 | 
            +
                    self.post_norm = post_norm
         | 
| 101 | 
            +
                    self.eps = eps
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                    # embeddings
         | 
| 104 | 
            +
                    self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
         | 
| 105 | 
            +
                    self.type_embedding = nn.Embedding(type_size, dim)
         | 
| 106 | 
            +
                    self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
         | 
| 107 | 
            +
                    self.dropout = nn.Dropout(dropout)
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                    # blocks
         | 
| 110 | 
            +
                    self.blocks = nn.ModuleList([
         | 
| 111 | 
            +
                        AttentionBlock(dim, num_heads, post_norm, dropout, eps)
         | 
| 112 | 
            +
                        for _ in range(num_layers)
         | 
| 113 | 
            +
                    ])
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                    # norm layer
         | 
| 116 | 
            +
                    self.norm = nn.LayerNorm(dim, eps=eps)
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                def forward(self, ids):
         | 
| 119 | 
            +
                    """
         | 
| 120 | 
            +
                    ids: [B, L] of torch.LongTensor.
         | 
| 121 | 
            +
                    """
         | 
| 122 | 
            +
                    b, s = ids.shape
         | 
| 123 | 
            +
                    mask = ids.ne(self.pad_id).long()
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                    # embeddings
         | 
| 126 | 
            +
                    x = self.token_embedding(ids) + \
         | 
| 127 | 
            +
                        self.type_embedding(torch.zeros_like(ids)) + \
         | 
| 128 | 
            +
                        self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
         | 
| 129 | 
            +
                    if self.post_norm:
         | 
| 130 | 
            +
                        x = self.norm(x)
         | 
| 131 | 
            +
                    x = self.dropout(x)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    # blocks
         | 
| 134 | 
            +
                    mask = torch.where(
         | 
| 135 | 
            +
                        mask.view(b, 1, 1, s).gt(0), 0.0,
         | 
| 136 | 
            +
                        torch.finfo(x.dtype).min)
         | 
| 137 | 
            +
                    for block in self.blocks:
         | 
| 138 | 
            +
                        x = block(x, mask)
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    # output
         | 
| 141 | 
            +
                    if not self.post_norm:
         | 
| 142 | 
            +
                        x = self.norm(x)
         | 
| 143 | 
            +
                    return x
         | 
| 144 | 
            +
             | 
| 145 | 
            +
             | 
| 146 | 
            +
            def xlm_roberta_large(pretrained=False,
         | 
| 147 | 
            +
                                  return_tokenizer=False,
         | 
| 148 | 
            +
                                  device='cpu',
         | 
| 149 | 
            +
                                  **kwargs):
         | 
| 150 | 
            +
                """
         | 
| 151 | 
            +
                XLMRobertaLarge adapted from Huggingface.
         | 
| 152 | 
            +
                """
         | 
| 153 | 
            +
                # params
         | 
| 154 | 
            +
                cfg = dict(
         | 
| 155 | 
            +
                    vocab_size=250002,
         | 
| 156 | 
            +
                    max_seq_len=514,
         | 
| 157 | 
            +
                    type_size=1,
         | 
| 158 | 
            +
                    pad_id=1,
         | 
| 159 | 
            +
                    dim=1024,
         | 
| 160 | 
            +
                    num_heads=16,
         | 
| 161 | 
            +
                    num_layers=24,
         | 
| 162 | 
            +
                    post_norm=True,
         | 
| 163 | 
            +
                    dropout=0.1,
         | 
| 164 | 
            +
                    eps=1e-5)
         | 
| 165 | 
            +
                cfg.update(**kwargs)
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                # init a model on device
         | 
| 168 | 
            +
                with torch.device(device):
         | 
| 169 | 
            +
                    model = XLMRoberta(**cfg)
         | 
| 170 | 
            +
                return model
         | 
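
For orientation, a shape-level sketch of the encoder above. The token ids are fabricated and no pretrained weights are loaded (the diff shows no loading path for pretrained=True), so this exercises tensor plumbing only:

    # Hypothetical shape check for XLMRoberta (random ids, untrained weights).
    import torch

    model = xlm_roberta_large(device="cpu").eval()

    ids = torch.full((2, 12), 1, dtype=torch.long)  # pad_id = 1 everywhere
    ids[0, :5] = torch.tensor([0, 250, 913, 4, 2])  # made-up tokens, sample 0
    ids[1, :3] = torch.tensor([0, 88, 2])           # made-up tokens, sample 1

    with torch.no_grad():
        out = model(ids)
    print(out.shape)                                # torch.Size([2, 12, 1024])

Padding positions are excluded via the large negative additive mask built in forward(), so the padded tail of each row does not affect the contextual embeddings.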
    	
        humo/utils/audio_processor_whisper.py
    ADDED
    
    | @@ -0,0 +1,173 @@ | |
| 1 | 
            +
            # pylint: disable=C0301
         | 
| 2 | 
            +
            '''
         | 
| 3 | 
            +
            This module contains the AudioProcessor class and related functions for processing audio data.
         | 
| 4 | 
            +
            It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction,
         | 
| 5 | 
            +
            and audio separation. The class is initialized with configuration parameters and can process
         | 
| 6 | 
            +
            audio files using the provided models.
         | 
| 7 | 
            +
            '''
         | 
| 8 | 
            +
            import os
         | 
| 9 | 
            +
            import subprocess
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import librosa
         | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
            import torch
         | 
| 14 | 
            +
            from audio_separator.separator import Separator
         | 
| 15 | 
            +
            from transformers import WhisperModel, AutoFeatureExtractor
         | 
| 16 | 
            +
            import torch.nn.functional as F
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            def linear_interpolation_fps(features, input_fps, output_fps, output_len=None):
         | 
| 20 | 
            +
                features = features.transpose(1, 2)  # [1, C, T]
         | 
| 21 | 
            +
                seq_len = features.shape[2] / float(input_fps)
         | 
| 22 | 
            +
                if output_len is None:
         | 
| 23 | 
            +
                    output_len = int(seq_len * output_fps)
         | 
| 24 | 
            +
                output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear')
         | 
| 25 | 
            +
                return output_features.transpose(1, 2)
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int):
         | 
| 29 | 
            +
                p = subprocess.Popen([
         | 
| 30 | 
            +
                    "ffmpeg", "-y", "-v", "error", "-i", input_audio_file, "-ar", str(sample_rate), output_audio_file
         | 
| 31 | 
            +
                ])
         | 
| 32 | 
            +
                ret = p.wait()
         | 
| 33 | 
            +
                assert ret == 0, "Resample audio failed!"
         | 
| 34 | 
            +
                return output_audio_file
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            class AudioProcessor:
         | 
| 37 | 
            +
                """
         | 
| 38 | 
            +
                AudioProcessor is a class that handles the processing of audio files.
         | 
| 39 | 
            +
                It takes care of preprocessing the audio files, extracting features
         | 
| 40 | 
            +
                using a Whisper encoder, and separating vocals when an audio separator model is configured.
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                :param sample_rate: Sampling rate of the audio file
         | 
| 43 | 
            +
                :param fps: Frames per second for the extracted features
         | 
| 44 | 
            +
                :param wav2vec_model_path: Path to the Whisper model (parameter name kept from the earlier wav2vec pipeline)
         | 
| 45 | 
            +
                :param wav2vec_feature_type: Type of audio feature to extract (e.g. "whisper")
         | 
| 46 | 
            +
                :param audio_separator_model_path: Path to the audio separator model
         | 
| 47 | 
            +
                :param audio_separator_model_name: Name of the audio separator model
         | 
| 48 | 
            +
                :param cache_dir: Directory to cache the intermediate results
         | 
| 49 | 
            +
                :param device: Device to run the processing on
         | 
| 50 | 
            +
                """
         | 
| 51 | 
            +
                def __init__(
         | 
| 52 | 
            +
                    self,
         | 
| 53 | 
            +
                    sample_rate,
         | 
| 54 | 
            +
                    fps,
         | 
| 55 | 
            +
                    wav2vec_model_path,
         | 
| 56 | 
            +
                    wav2vec_feature_type,
         | 
| 57 | 
            +
                    audio_separator_model_path:str=None,
         | 
| 58 | 
            +
                    audio_separator_model_name:str=None,
         | 
| 59 | 
            +
                    cache_dir:str='',
         | 
| 60 | 
            +
                    device="cuda:0",
         | 
| 61 | 
            +
                ) -> None:
         | 
| 62 | 
            +
                    self.sample_rate = sample_rate
         | 
| 63 | 
            +
                    self.fps = fps
         | 
| 64 | 
            +
                    self.device = device
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    self.whisper = WhisperModel.from_pretrained(wav2vec_model_path).to(device).eval()
         | 
| 67 | 
            +
                    self.whisper.requires_grad_(False)
         | 
| 68 | 
            +
                    self.feature_extractor = AutoFeatureExtractor.from_pretrained(wav2vec_model_path)
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    if audio_separator_model_name is not None:
         | 
| 71 | 
            +
                        try:
         | 
| 72 | 
            +
                            os.makedirs(cache_dir, exist_ok=True)
         | 
| 73 | 
            +
                        except OSError as _:
         | 
| 74 | 
            +
                            print("Fail to create the output cache dir.")
         | 
| 75 | 
            +
                        self.audio_separator = Separator(
         | 
| 76 | 
            +
                            output_dir=cache_dir,
         | 
| 77 | 
            +
                            output_single_stem="vocals",
         | 
| 78 | 
            +
                            model_file_dir=audio_separator_model_path,
         | 
| 79 | 
            +
                        )
         | 
| 80 | 
            +
                        self.audio_separator.load_model(audio_separator_model_name)
         | 
| 81 | 
            +
                        assert self.audio_separator.model_instance is not None, "Failed to load the audio separator model."
         | 
| 82 | 
            +
                    else:
         | 
| 83 | 
            +
                        self.audio_separator = None
         | 
| 84 | 
            +
                        print("Use audio directly without vocals seperator.")        
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             | 
| 87 | 
            +
                def get_audio_feature(self, audio_path):
         | 
| 88 | 
            +
                    audio_input, sampling_rate = librosa.load(audio_path, sr=16000)
         | 
| 89 | 
            +
                    assert sampling_rate == 16000
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                    audio_features = []
         | 
| 92 | 
            +
                    window = 750*640
         | 
| 93 | 
            +
                    for i in range(0, len(audio_input), window):
         | 
| 94 | 
            +
                        audio_feature = self.feature_extractor(audio_input[i:i+window], 
         | 
| 95 | 
            +
                                                        sampling_rate=sampling_rate, 
         | 
| 96 | 
            +
                                                        return_tensors="pt", 
         | 
| 97 | 
            +
                                                        ).input_features
         | 
| 98 | 
            +
                        audio_features.append(audio_feature)
         | 
| 99 | 
            +
                    audio_features = torch.cat(audio_features, dim=-1)
         | 
| 100 | 
            +
                    return audio_features, len(audio_input) // 640
         | 
| 101 | 
            +
             | 
| 102 | 
            +
             | 
| 103 | 
            +
                def preprocess(self, audio_path: str):
         | 
| 104 | 
            +
                    audio_input, audio_len = self.get_audio_feature(audio_path)
         | 
| 105 | 
            +
                    audio_feature = audio_input.to(self.whisper.device).float()
         | 
| 106 | 
            +
                    window = 3000
         | 
| 107 | 
            +
                    audio_prompts = []
         | 
| 108 | 
            +
                    for i in range(0, audio_feature.shape[-1], window):
         | 
| 109 | 
            +
                        audio_prompt = self.whisper.encoder(audio_feature[:,:,i:i+window], output_hidden_states=True).hidden_states
         | 
| 110 | 
            +
                        audio_prompt = torch.stack(audio_prompt, dim=2)
         | 
| 111 | 
            +
                        audio_prompts.append(audio_prompt)
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    audio_prompts = torch.cat(audio_prompts, dim=1)
         | 
| 114 | 
            +
                    audio_prompts = audio_prompts[:,:audio_len*2]
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    audio_emb = self.audio_emb_enc(audio_prompts, wav_enc_type="whisper")
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                    return audio_emb, audio_emb.shape[0]
         | 
| 119 | 
            +
                
         | 
| 120 | 
            +
                def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"):
         | 
| 121 | 
            +
                    if wav_enc_type == "wav2vec":
         | 
| 122 | 
            +
                        feat_merge = audio_emb
         | 
| 123 | 
            +
                    elif wav_enc_type == "whisper":
         | 
| 124 | 
            +
                        # [1, T, 33, 1280]
         | 
| 125 | 
            +
                        feat0 = linear_interpolation_fps(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25)
         | 
| 126 | 
            +
                        feat1 = linear_interpolation_fps(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25)
         | 
| 127 | 
            +
                        feat2 = linear_interpolation_fps(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25)
         | 
| 128 | 
            +
                        feat3 = linear_interpolation_fps(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25)
         | 
| 129 | 
            +
                        feat4 = linear_interpolation_fps(audio_emb[:, :, 32], 50, 25)
         | 
| 130 | 
            +
                        feat_merge = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0]  # [T, 5, 1280]
         | 
| 131 | 
            +
                    else:
         | 
| 132 | 
            +
                        raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}")
         | 
| 133 | 
            +
                    
         | 
| 134 | 
            +
                    return feat_merge
         | 
| 135 | 
            +
                
         | 
| 136 | 
            +
                def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2):
         | 
| 137 | 
            +
                    zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
         | 
| 138 | 
            +
                    zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
         | 
| 139 | 
            +
                    iter_ = 1 + (frame_num - 1) // 4
         | 
| 140 | 
            +
                    audio_emb_wind = []
         | 
| 141 | 
            +
                    for lt_i in range(iter_):
         | 
| 142 | 
            +
                        if lt_i == 0:  # latent_i
         | 
| 143 | 
            +
                            # first VAE latent frame: left-pad the audio window with zeros to mark the clip start
         | 
| 144 | 
            +
                            st = frame0_idx + lt_i - 2
         | 
| 145 | 
            +
                            ed = frame0_idx + lt_i + 3
         | 
| 146 | 
            +
                            wind_feat = torch.stack([
         | 
| 147 | 
            +
                                audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
         | 
| 148 | 
            +
                                for i in range(st, ed)
         | 
| 149 | 
            +
                            ], dim=0)  # [5, 13, 768] (wav2vec) or [5, 5, 1280] (whisper)
         | 
| 150 | 
            +
                            wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0)  # [8, 13, 768] / [8, 5, 1280]
         | 
| 151 | 
            +
                        else:
         | 
| 152 | 
            +
                            st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
         | 
| 153 | 
            +
                            ed = frame0_idx + 1 + 4 * lt_i + audio_shift
         | 
| 154 | 
            +
                            wind_feat = torch.stack([
         | 
| 155 | 
            +
                                audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
         | 
| 156 | 
            +
                                for i in range(st, ed)
         | 
| 157 | 
            +
                            ], dim=0)  # [8, 13, 768] / [8, 5, 1280]
         | 
| 158 | 
            +
                        audio_emb_wind.append(wind_feat)
         | 
| 159 | 
            +
                    audio_emb_wind = torch.stack(audio_emb_wind, dim=0)  # [iter_, 8, 13, 768] / [iter_, 8, 5, 1280]
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                    return audio_emb_wind, ed - audio_shift
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                def close(self):
         | 
| 164 | 
            +
                    """
         | 
| 165 | 
            +
                    TODO: to be implemented
         | 
| 166 | 
            +
                    """
         | 
| 167 | 
            +
                    return self
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                def __enter__(self):
         | 
| 170 | 
            +
                    return self
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                def __exit__(self, _exc_type, _exc_val, _exc_tb):
         | 
| 173 | 
            +
                    self.close()
         | 
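
Putting the class above together, a hedged end-to-end sketch: preprocess() yields per-frame Whisper embeddings at 25 fps, and get_audio_emb_window() slices them into one window per VAE latent step. The model and audio paths are placeholders, not values from this repo's configs; the checkpoint must be a whisper-large-style model, since audio_emb_enc indexes 33 encoder hidden states.

    # Hypothetical usage of AudioProcessor (paths are placeholders).
    processor = AudioProcessor(
        sample_rate=16000,
        fps=25,
        wav2vec_model_path="path/to/whisper-large",  # assumed: 32-layer Whisper encoder
        wav2vec_feature_type="whisper",
    )

    audio_emb, num_frames = processor.preprocess("path/to/speech.wav")
    # audio_emb: [T, 5, 1280] -- five pooled groups of Whisper hidden states per frame

    # One audio window per latent step of a 97-frame clip starting at frame 0.
    windows, next_idx = processor.get_audio_emb_window(audio_emb, frame_num=97, frame0_idx=0)
    # windows: [1 + (97 - 1) // 4, 8, 5, 1280] = [25, 8, 5, 1280]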
    	
        humo/utils/wav2vec.py
    ADDED
    
    | @@ -0,0 +1,218 @@ | |
+# pylint: disable=R0901
+# src/models/wav2vec.py
+
+"""
+This module defines the Wav2Vec model, a pre-trained model for speech recognition and understanding.
+It inherits from the Wav2Vec2Model class in the transformers library and adds functionality
+for feature extraction and encoding.
+
+Classes:
+    Wav2VecModel: Inherits from Wav2Vec2Model and adds methods for feature extraction and encoding.
+
+Functions:
+    linear_interpolation: Interpolates the features to a target sequence length.
+"""
+
+import torch.nn.functional as F
+from transformers import Wav2Vec2Model
+from transformers.modeling_outputs import BaseModelOutput
+
+
+class Wav2VecModel(Wav2Vec2Model):
+    """
+    Wav2VecModel extends the Wav2Vec2Model class from the transformers library.
+    It inherits all the functionality of Wav2Vec2Model and adds methods for
+    feature extraction and encoding.
+
+    Attributes:
+        base_model (Wav2Vec2Model): The base Wav2Vec2Model object.
+
+    Methods:
+        forward(input_values, seq_len, attention_mask=None, mask_time_indices=None,
+                output_attentions=None, output_hidden_states=None, return_dict=None):
+            Forward pass of the Wav2VecModel; returns the output of the base model.
+
+        feature_extract(input_values, seq_len):
+            Extracts features from input_values using the base model.
+
+        encode(extract_features, attention_mask=None, mask_time_indices=None,
+               output_attentions=None, output_hidden_states=None, return_dict=None):
+            Encodes the extracted features using the base model and returns them.
+    """
+    def forward(
+        self,
+        input_values,
+        seq_len,
+        attention_mask=None,
+        mask_time_indices=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        """
+        Forward pass of the Wav2Vec model.
+
+        Args:
+            input_values: The input waveform.
+            seq_len: The target sequence length the features are interpolated to.
+            attention_mask: Attention mask for the model.
+            mask_time_indices: Time indices to mask.
+            output_attentions: If set to True, returns attentions.
+            output_hidden_states: If set to True, returns hidden states.
+            return_dict: If set to True, returns a BaseModelOutput instead of a tuple.
+
+        Returns:
+            The output of the Wav2Vec model.
+        """
+        self.config.output_attentions = True
+
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        extract_features = self.feature_extractor(input_values)
+        extract_features = extract_features.transpose(1, 2)
+        extract_features = linear_interpolation(extract_features, seq_len=seq_len)
+
+        if attention_mask is not None:
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(
+                extract_features.shape[1], attention_mask, add_adapter=False
+            )
+
+        hidden_states, extract_features = self.feature_projection(extract_features)
+        hidden_states = self._mask_hidden_states(
+            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+        )
+
+        encoder_outputs = self.encoder(
+            hidden_states,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = encoder_outputs[0]
+
+        if self.adapter is not None:
+            hidden_states = self.adapter(hidden_states)
+
+        if not return_dict:
+            return (hidden_states, ) + encoder_outputs[1:]
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+    def feature_extract(
+        self,
+        input_values,
+        seq_len,
+    ):
+        """
+        Extracts features from the input values and returns the extracted features.
+
+        Args:
+            input_values (torch.Tensor): The input values to be processed.
+            seq_len (int): The target sequence length.
+
+        Returns:
+            torch.Tensor: The features extracted from the input values.
+        """
+        extract_features = self.feature_extractor(input_values)
+        extract_features = extract_features.transpose(1, 2)
+        extract_features = linear_interpolation(extract_features, seq_len=seq_len)
+
+        return extract_features
+
+    def encode(
+        self,
+        extract_features,
+        attention_mask=None,
+        mask_time_indices=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        """
+        Encodes the input features into the output space.
+
+        Args:
+            extract_features (torch.Tensor): The features extracted from the audio signal.
+            attention_mask (torch.Tensor, optional): Attention mask to be used for padding.
+            mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension.
+            output_attentions (bool, optional): If set to True, returns the attention weights.
+            output_hidden_states (bool, optional): If set to True, returns all hidden states.
+            return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of a tuple.
+
+        Returns:
+            The encoded output features.
+        """
+        self.config.output_attentions = True
+
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if attention_mask is not None:
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(
+                extract_features.shape[1], attention_mask, add_adapter=False
+            )
+
+        hidden_states, extract_features = self.feature_projection(extract_features)
+        hidden_states = self._mask_hidden_states(
+            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+        )
+
+        encoder_outputs = self.encoder(
+            hidden_states,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = encoder_outputs[0]
+
+        if self.adapter is not None:
+            hidden_states = self.adapter(hidden_states)
+
+        if not return_dict:
+            return (hidden_states, ) + encoder_outputs[1:]
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
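An aside on usage: compared with the stock Wav2Vec2Model, the `forward` above takes a `seq_len` and linearly interpolates the convolutional features to that length (via `linear_interpolation`, defined next in this file), so the output aligns one-to-one with video frames. A minimal sketch; the checkpoint name and the 16 kHz, 2-second dummy input are illustrative assumptions, not taken from this repo's configs:

```python
import torch
from wav2vec import Wav2VecModel  # i.e. this file, humo/utils/wav2vec.py, on the path

model = Wav2VecModel.from_pretrained("facebook/wav2vec2-base-960h")  # assumed checkpoint
model.eval()

waveform = torch.randn(1, 16000 * 2)   # 2 s of dummy 16 kHz audio
num_video_frames = 50                  # e.g. 2 s of video at 25 fps
with torch.no_grad():
    out = model(waveform, seq_len=num_video_frames)
print(out.last_hidden_state.shape)     # torch.Size([1, 50, 768])
```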
+def linear_interpolation(features, seq_len):
+    """
+    Linearly interpolate the features along the time axis to a target length.
+
+    Args:
+        features (torch.Tensor): The extracted features, shape [B, T, C].
+        seq_len (int): The target sequence length.
+
+    Returns:
+        torch.Tensor: The interpolated features, shape [B, seq_len, C].
+    """
+    features = features.transpose(1, 2)
+    output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
+    return output_features.transpose(1, 2)
+
+
+def linear_interpolation_fps(features, input_fps, output_fps, output_len=None):
+    # Resample features from input_fps to output_fps along the time axis.
+    features = features.transpose(1, 2)  # [1, C, T]
+    seq_len = features.shape[2] / float(input_fps)  # duration in seconds
+    if output_len is None:
+        output_len = int(seq_len * output_fps)  # target frame count at output_fps
+    output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear')
+    return output_features.transpose(1, 2)
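These two helpers do the temporal alignment: `linear_interpolation` stretches features to an explicit target length, while `linear_interpolation_fps` derives the target length from a frame-rate ratio. A quick sketch of the fps variant; the shapes and rates are illustrative, and the import path assumes the repo root is on `sys.path`:

```python
import torch
from humo.utils.wav2vec import linear_interpolation_fps  # assumed import path

feats = torch.randn(1, 90, 768)   # 3 s of features at 30 fps, [B, T, C]
resampled = linear_interpolation_fps(feats, input_fps=30, output_fps=25)
print(resampled.shape)            # torch.Size([1, 75, 768]) -> 3 s at 25 fps
```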
    	
main.py
ADDED
@@ -0,0 +1,28 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Inference codes adapted from [SeedVR]
+# https://github.com/ByteDance-Seed/SeedVR/blob/main/projects/inference_seedvr2_7b.py
+
+from sys import argv
+import sys
+
+path_to_insert = "humo"
+if path_to_insert not in sys.path:
+    sys.path.insert(0, path_to_insert)
+
+from common.config import load_config, create_object
+
+# Load config.
+config = load_config(argv[1], argv[2:])
+
+runner = create_object(config)
+runner.entrypoint()
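`main.py` is the CLI entrypoint: the first argument is a config path and any remaining arguments are forwarded to `load_config` as overrides. A sketch of the equivalent programmatic call; the config path is one of the files in this upload, and passing an empty override list is an assumption about `load_config`:

```python
import sys

# Mirror main.py: make the in-repo "humo" package importable.
if "humo" not in sys.path:
    sys.path.insert(0, "humo")

from common.config import load_config, create_object

# Hypothetical direct call with no CLI overrides.
config = load_config("humo/configs/inference/generate.yaml", [])
runner = create_object(config)  # instantiate the runner class named in the config
runner.entrypoint()             # run inference
```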
