Duplicate from tencent/Hunyuan3D-2.1
Browse filesCo-authored-by: huiwenshi <[email protected]>
- .gitattributes +35 -0
- LICENSE +82 -0
- Notice.txt +122 -0
- README.md +78 -0
- demo.py +47 -0
- hunyuan3d-dit-v2-1/config.yaml +82 -0
- hunyuan3d-dit-v2-1/model.fp16.ckpt +3 -0
- hunyuan3d-paintpbr-v2-1/README.md +53 -0
- hunyuan3d-paintpbr-v2-1/feature_extractor/preprocessor_config.json +20 -0
- hunyuan3d-paintpbr-v2-1/image_encoder/config.json +23 -0
- hunyuan3d-paintpbr-v2-1/image_encoder/model.safetensors +3 -0
- hunyuan3d-paintpbr-v2-1/model_index.json +37 -0
- hunyuan3d-paintpbr-v2-1/scheduler/scheduler_config.json +15 -0
- hunyuan3d-paintpbr-v2-1/text_encoder/config.json +25 -0
- hunyuan3d-paintpbr-v2-1/text_encoder/pytorch_model.bin +3 -0
- hunyuan3d-paintpbr-v2-1/tokenizer/merges.txt +0 -0
- hunyuan3d-paintpbr-v2-1/tokenizer/special_tokens_map.json +24 -0
- hunyuan3d-paintpbr-v2-1/tokenizer/tokenizer_config.json +34 -0
- hunyuan3d-paintpbr-v2-1/tokenizer/vocab.json +0 -0
- hunyuan3d-paintpbr-v2-1/unet/attn_processor.py +839 -0
- hunyuan3d-paintpbr-v2-1/unet/config.json +45 -0
- hunyuan3d-paintpbr-v2-1/unet/diffusion_pytorch_model.bin +3 -0
- hunyuan3d-paintpbr-v2-1/unet/model.py +622 -0
- hunyuan3d-paintpbr-v2-1/unet/modules.py +1102 -0
- hunyuan3d-paintpbr-v2-1/vae/config.json +29 -0
- hunyuan3d-paintpbr-v2-1/vae/diffusion_pytorch_model.bin +3 -0
- hunyuan3d-vae-v2-1/config.yaml +19 -0
- hunyuan3d-vae-v2-1/model.fp16.ckpt +3 -0
- hy3dpaint/textureGenPipeline.py +192 -0
- hy3dpaint/utils/multiview_utils.py +128 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TENCENT HUNYUAN 3D 2.1 COMMUNITY LICENSE AGREEMENT
|
| 2 |
+
Tencent Hunyuan 3D 2.1 Release Date: June 13, 2025
|
| 3 |
+
THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
|
| 4 |
+
By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan 3D 2.1 Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
|
| 5 |
+
1. DEFINITIONS.
|
| 6 |
+
a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
|
| 7 |
+
b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan 3D 2.1 Works or any portion or element thereof set forth herein.
|
| 8 |
+
c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan 3D 2.1 made publicly available by Tencent.
|
| 9 |
+
d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
|
| 10 |
+
e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan 3D 2.1 Works for any purpose and in any field of use.
|
| 11 |
+
f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan 3D 2.1 and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
|
| 12 |
+
g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan 3D 2.1 or any Model Derivative of Tencent Hunyuan 3D 2.1; (ii) works based on Tencent Hunyuan 3D 2.1 or any Model Derivative of Tencent Hunyuan 3D 2.1; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan 3D 2.1 or any Model Derivative of Tencent Hunyuan 3D 2.1, to that model in order to cause that model to perform similarly to Tencent Hunyuan 3D 2.1 or a Model Derivative of Tencent Hunyuan 3D 2.1, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan 3D 2.1 or a Model Derivative of Tencent Hunyuan 3D 2.1 for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
|
| 13 |
+
h. “Output” shall mean the information and/or content output of Tencent Hunyuan 3D 2.1 or a Model Derivative that results from operating or otherwise using Tencent Hunyuan 3D 2.1 or a Model Derivative, including via a Hosted Service.
|
| 14 |
+
i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials.
|
| 15 |
+
* Section 1.i of the previous Hunyuan License Agreement defined “Tencent,” “We” or “Us” to mean THL A29 Limited, and the copyright notices pertaining to the Materials were previously in the name of “THL A29 Limited.” That entity has now been de-registered. You should treat all previously distributed copies of the Materials as if Section 1.i of the Agreement defined “Tencent,” “We” or “Us” to mean “the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials,” and treat the copyright notice(s) accompanying the Materials as if they were in the name of “Tencent.” When providing a copy of any Agreement to Third Party recipients of the Tencent Hunyuan Works or products or services using them, as required by Section 3.a of the Agreement, you should provide the most current version of the Agreement, including the change of definition in Section 1.i of the Agreement.
|
| 16 |
+
j. “Tencent Hunyuan 3D 2.1” shall mean the 3D generation models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us at [ https://github.com/Tencent-Hunyuan/Hunyuan3D-2.1].
|
| 17 |
+
k. “Tencent Hunyuan 3D 2.1 Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
|
| 18 |
+
l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
|
| 19 |
+
m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
|
| 20 |
+
n. “including” shall mean including but not limited to.
|
| 21 |
+
2. GRANT OF RIGHTS.
|
| 22 |
+
We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
|
| 23 |
+
3. DISTRIBUTION.
|
| 24 |
+
You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan 3D 2.1 Works, exclusively in the Territory, provided that You meet all of the following conditions:
|
| 25 |
+
a. You must provide all such Third Party recipients of the Tencent Hunyuan 3D 2.1 Works or products or services using them a copy of this Agreement;
|
| 26 |
+
b. You must cause any modified files to carry prominent notices stating that You changed the files;
|
| 27 |
+
c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan 3D 2.1 Works; and (ii) mark the products or services developed by using the Tencent Hunyuan 3D 2.1 Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
|
| 28 |
+
d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan 3D 2.1 is licensed under the Tencent Hunyuan 3D 2.1 Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
|
| 29 |
+
You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan 3D 2.1 Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
|
| 30 |
+
4. ADDITIONAL COMMERCIAL TERMS.
|
| 31 |
+
If, on the Tencent Hunyuan 3D 2.1 version release date, the monthly active users of all products or services made available by or for Licensee is greater than 1 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
|
| 32 |
+
Subject to Tencent's written approval, you may request a license for the use of Tencent Hunyuan 3D 2.1 by submitting the following information to [email protected]:
|
| 33 |
+
a. Your company’s name and associated business sector that plans to use Tencent Hunyuan 3D 2.1.
|
| 34 |
+
b. Your intended use case and the purpose of using Tencent Hunyuan 3D 2.1.
|
| 35 |
+
c. Your plans to modify Tencent Hunyuan 3D 2.1 or create Model Derivatives.
|
| 36 |
+
5. RULES OF USE.
|
| 37 |
+
a. Your use of the Tencent Hunyuan 3D 2.1 Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan 3D 2.1 Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan 3D 2.1 Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan 3D 2.1 Works are subject to the use restrictions in these Sections 5(a) and 5(b).
|
| 38 |
+
b. You must not use the Tencent Hunyuan 3D 2.1 Works or any Output or results of the Tencent Hunyuan 3D 2.1 Works to improve any other AI model (other than Tencent Hunyuan 3D 2.1 or Model Derivatives thereof).
|
| 39 |
+
c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan 3D 2.1 Works, Output or results of the Tencent Hunyuan 3D 2.1 Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
|
| 40 |
+
6. INTELLECTUAL PROPERTY.
|
| 41 |
+
a. Subject to Tencent’s ownership of Tencent Hunyuan 3D 2.1 Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
|
| 42 |
+
b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan 3D 2.1 Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan 3D 2.1 Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
|
| 43 |
+
c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan 3D 2.1 Works.
|
| 44 |
+
d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
|
| 45 |
+
7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
|
| 46 |
+
a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan 3D 2.1 Works or to grant any license thereto.
|
| 47 |
+
b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN 3D 2.1 WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN 3D 2.1 WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN 3D 2.1 WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
|
| 48 |
+
c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN 3D 2.1 WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
| 49 |
+
8. SURVIVAL AND TERMINATION.
|
| 50 |
+
a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
|
| 51 |
+
b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan 3D 2.1 Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
|
| 52 |
+
9. GOVERNING LAW AND JURISDICTION.
|
| 53 |
+
a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
|
| 54 |
+
b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
|
| 55 |
+
|
| 56 |
+
EXHIBIT A
|
| 57 |
+
ACCEPTABLE USE POLICY
|
| 58 |
+
|
| 59 |
+
Tencent reserves the right to update this Acceptable Use Policy from time to time.
|
| 60 |
+
Last modified: November 5, 2024
|
| 61 |
+
|
| 62 |
+
Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan 3D 2.1. You agree not to use Tencent Hunyuan 3D 2.1 or Model Derivatives:
|
| 63 |
+
1. Outside the Territory;
|
| 64 |
+
2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
|
| 65 |
+
3. To harm Yourself or others;
|
| 66 |
+
4. To repurpose or distribute output from Tencent Hunyuan 3D 2.1 or any Model Derivatives to harm Yourself or others;
|
| 67 |
+
5. To override or circumvent the safety guardrails and safeguards We have put in place;
|
| 68 |
+
6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
| 69 |
+
7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
|
| 70 |
+
8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
|
| 71 |
+
9. To intentionally defame, disparage or otherwise harass others;
|
| 72 |
+
10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
|
| 73 |
+
11. To generate or disseminate personal identifiable information with the purpose of harming others;
|
| 74 |
+
12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
|
| 75 |
+
13. To impersonate another individual without consent, authorization, or legal right;
|
| 76 |
+
14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
|
| 77 |
+
15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
|
| 78 |
+
16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
|
| 79 |
+
17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
|
| 80 |
+
18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
| 81 |
+
19. For military purposes;
|
| 82 |
+
20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
|
Notice.txt
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Usage and Legal Notices:
|
| 2 |
+
|
| 3 |
+
Tencent is pleased to support the open source community by making Hunyuan 3D 2.1 available.
|
| 4 |
+
|
| 5 |
+
Copyright (C) 2025 Tencent. All rights reserved. The below software and/or models in this distribution may have been modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) Tencent.
|
| 6 |
+
|
| 7 |
+
Hunyuan 3D 2.1 is licensed under the TENCENT HUNYUAN 3D 2.1 COMMUNITY LICENSE AGREEMENT except for the third-party components listed below, which is licensed under different terms. Hunyuan 3D 2.1 does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
|
| 8 |
+
|
| 9 |
+
For avoidance of doubts, Hunyuan 3D 2.1 means inference-enabling code, parameters, and weights of this Model only, which are made publicly available by Tencent in accordance with TENCENT HUNYUAN 3D 2.1 COMMUNITY LICENSE AGREEMENT.
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
Other dependencies and licenses:
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
Open Source Model Licensed under the MIT and CreativeML Open RAIL++-M License:
|
| 16 |
+
--------------------------------------------------------------------
|
| 17 |
+
1. Stable Diffusion
|
| 18 |
+
Copyright (c) 2022 Stability AI
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
Terms of the MIT and CreativeML Open RAIL++-M License:
|
| 22 |
+
--------------------------------------------------------------------
|
| 23 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 24 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 25 |
+
in the Software without restriction, including without limitation the rights
|
| 26 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 27 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 28 |
+
furnished to do so, subject to the following conditions:
|
| 29 |
+
|
| 30 |
+
The above copyright notice and this permission notice shall be included in all
|
| 31 |
+
copies or substantial portions of the Software.
|
| 32 |
+
|
| 33 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 34 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 35 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 36 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 37 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 38 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 39 |
+
SOFTWARE.
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
CreativeML Open RAIL++-M License
|
| 43 |
+
dated November 24, 2022
|
| 44 |
+
|
| 45 |
+
Section I: PREAMBLE
|
| 46 |
+
|
| 47 |
+
Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.
|
| 48 |
+
|
| 49 |
+
Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
|
| 50 |
+
|
| 51 |
+
In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.
|
| 52 |
+
|
| 53 |
+
Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.
|
| 54 |
+
|
| 55 |
+
This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
|
| 56 |
+
|
| 57 |
+
NOW THEREFORE, You and Licensor agree as follows:
|
| 58 |
+
|
| 59 |
+
1. Definitions
|
| 60 |
+
|
| 61 |
+
- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
|
| 62 |
+
- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
|
| 63 |
+
- "Output" means the results of operating a Model as embodied in informational content resulting therefrom.
|
| 64 |
+
- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
|
| 65 |
+
- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
|
| 66 |
+
- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
|
| 67 |
+
- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
|
| 68 |
+
- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
|
| 69 |
+
- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
|
| 70 |
+
- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You.
|
| 71 |
+
- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
| 72 |
+
- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.
|
| 73 |
+
|
| 74 |
+
Section II: INTELLECTUAL PROPERTY RIGHTS
|
| 75 |
+
|
| 76 |
+
Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
|
| 77 |
+
|
| 78 |
+
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
|
| 79 |
+
3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.
|
| 80 |
+
|
| 81 |
+
Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
|
| 82 |
+
|
| 83 |
+
4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
|
| 84 |
+
Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
|
| 85 |
+
You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
|
| 86 |
+
You must cause any modified files to carry prominent notices stating that You changed the files;
|
| 87 |
+
You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
|
| 88 |
+
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
|
| 89 |
+
5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
|
| 90 |
+
6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
|
| 91 |
+
|
| 92 |
+
Section IV: OTHER PROVISIONS
|
| 93 |
+
|
| 94 |
+
7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License.
|
| 95 |
+
8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
|
| 96 |
+
9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
|
| 97 |
+
10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
| 98 |
+
11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
| 99 |
+
12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
|
| 100 |
+
|
| 101 |
+
END OF TERMS AND CONDITIONS
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
Attachment A
|
| 107 |
+
|
| 108 |
+
Use Restrictions
|
| 109 |
+
|
| 110 |
+
You agree not to use the Model or Derivatives of the Model:
|
| 111 |
+
|
| 112 |
+
- In any way that violates any applicable national, federal, state, local or international law or regulation;
|
| 113 |
+
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
| 114 |
+
- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
|
| 115 |
+
- To generate or disseminate personal identifiable information that can be used to harm an individual;
|
| 116 |
+
- To defame, disparage or otherwise harass others;
|
| 117 |
+
- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
|
| 118 |
+
- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
|
| 119 |
+
- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
| 120 |
+
- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
|
| 121 |
+
- To provide medical advice and medical results interpretation;
|
| 122 |
+
- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).
|
README.md
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: hunyuan3d-2
|
| 3 |
+
license: other
|
| 4 |
+
license_name: tencent-hunyuan-community
|
| 5 |
+
license_link: https://github.com/Tencent-Hunyuan/Hunyuan3D-2.1/blob/main/LICENSE
|
| 6 |
+
language:
|
| 7 |
+
- en
|
| 8 |
+
- zh
|
| 9 |
+
tags:
|
| 10 |
+
- image-to-3d
|
| 11 |
+
- text-to-3d
|
| 12 |
+
pipeline_tag: image-to-3d
|
| 13 |
+
extra_gated_eu_disallowed: true
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
<p align="center">
|
| 17 |
+
<img src="https://raw.githubusercontent.com/Tencent-Hunyuan/Hunyuan3D-2.1/refs/heads/main/assets/images/teaser.jpg">
|
| 18 |
+
</p>
|
| 19 |
+
|
| 20 |
+
<div align="center">
|
| 21 |
+
<a href=https://3d.hunyuan.tencent.com target="_blank"><img src=https://img.shields.io/badge/Hunyuan3D-black.svg?logo=homepage height=22px></a>
|
| 22 |
+
<a href=https://huggingface.co/spaces/tencent/Hunyuan3D-2.1 target="_blank"><img src=https://img.shields.io/badge/%F0%9F%A4%97%20Demo-276cb4.svg height=22px></a>
|
| 23 |
+
<a href=https://huggingface.co/tencent/Hunyuan3D-2.1 target="_blank"><img src=https://img.shields.io/badge/%F0%9F%A4%97%20Models-d96902.svg height=22px></a>
|
| 24 |
+
<a href=https://github.com/Tencent-Hunyuan/Hunyuan3D-2.1 target="_blank"><img src= https://img.shields.io/badge/Page-bb8a2e.svg?logo=github height=22px></a>
|
| 25 |
+
<a href=https://discord.gg/GuaWYwzKbX target="_blank"><img src= https://img.shields.io/badge/Discord-white.svg?logo=discord height=22px></a>
|
| 26 |
+
<a href=https://arxiv.org/abs/2506.15442 target="_blank"><img src=https://img.shields.io/badge/Report-b5212f.svg?logo=arxiv height=22px></a>
|
| 27 |
+
</div>
|
| 28 |
+
|
| 29 |
+
## 🔗 BibTeX
|
| 30 |
+
|
| 31 |
+
If you found this repository helpful, please cite our report:
|
| 32 |
+
|
| 33 |
+
```bibtex
|
| 34 |
+
@misc{hunyuan3d2025hunyuan3d,
|
| 35 |
+
title={Hunyuan3D 2.1: From Images to High-Fidelity 3D Assets with Production-Ready PBR Material},
|
| 36 |
+
author={Team Hunyuan3D and Shuhui Yang and Mingxin Yang and Yifei Feng and Xin Huang and Sheng Zhang and Zebin He and Di Luo and Haolin Liu and Yunfei Zhao and Qingxiang Lin and Zeqiang Lai and Xianghui Yang and Huiwen Shi and Zibo Zhao and Bowen Zhang and Hongyu Yan and Lifu Wang and Sicong Liu and Jihong Zhang and Meng Chen and Liang Dong and Yiwen Jia and Yulin Cai and Jiaao Yu and Yixuan Tang and Dongyuan Guo and Junlin Yu and Hao Zhang and Zheng Ye and Peng He and Runzhou Wu and Shida Wei and Chao Zhang and Yonghao Tan and Yifu Sun and Lin Niu and Shirui Huang and Bojian Zheng and Shu Liu and Shilin Chen and Xiang Yuan and Xiaofeng Yang and Kai Liu and Jianchen Zhu and Peng Chen and Tian Liu and Di Wang and Yuhong Liu and Linus and Jie Jiang and Jingwei Huang and Chunchao Guo},
|
| 37 |
+
year={2025},
|
| 38 |
+
eprint={2506.15442},
|
| 39 |
+
archivePrefix={arXiv},
|
| 40 |
+
primaryClass={cs.CV}
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
@misc{hunyuan3d22025tencent,
|
| 44 |
+
title={Hunyuan3D 2.0: Scaling Diffusion Models for High Resolution Textured 3D Assets Generation},
|
| 45 |
+
author={Tencent Hunyuan3D Team},
|
| 46 |
+
year={2025},
|
| 47 |
+
eprint={2501.12202},
|
| 48 |
+
archivePrefix={arXiv},
|
| 49 |
+
primaryClass={cs.CV}
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
@misc{yang2024tencent,
|
| 53 |
+
title={Tencent Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D Generation},
|
| 54 |
+
author={Tencent Hunyuan3D Team},
|
| 55 |
+
year={2024},
|
| 56 |
+
eprint={2411.02293},
|
| 57 |
+
archivePrefix={arXiv},
|
| 58 |
+
primaryClass={cs.CV}
|
| 59 |
+
}
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
## Acknowledgements
|
| 65 |
+
|
| 66 |
+
We would like to thank the contributors to
|
| 67 |
+
the [TripoSG](https://github.com/VAST-AI-Research/TripoSG), [DINOv2](https://github.com/facebookresearch/dinov2), [Stable Diffusion](https://github.com/Stability-AI/stablediffusion), [FLUX](https://github.com/black-forest-labs/flux), [diffusers](https://github.com/huggingface/diffusers)
|
| 68 |
+
and [HuggingFace](https://huggingface.co) repositories, for their open research and exploration.
|
| 69 |
+
|
| 70 |
+
## Star History
|
| 71 |
+
|
| 72 |
+
<a href="https://star-history.com/#Tencent-Hunyuan/Hunyuan3D-2.1&Date">
|
| 73 |
+
<picture>
|
| 74 |
+
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Tencent-Hunyuan/Hunyuan3D-2.1&type=Date&theme=dark" />
|
| 75 |
+
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Tencent-Hunyuan/Hunyuan3D-2.1&type=Date" />
|
| 76 |
+
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Tencent-Hunyuan/Hunyuan3D-2.1&type=Date" />
|
| 77 |
+
</picture>
|
| 78 |
+
</a>
|
demo.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.insert(0, './hy3dshape')
|
| 3 |
+
sys.path.insert(0, './hy3dpaint')
|
| 4 |
+
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from hy3dshape.rembg import BackgroundRemover
|
| 7 |
+
from hy3dshape.pipelines import Hunyuan3DDiTFlowMatchingPipeline
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
from textureGenPipeline import Hunyuan3DPaintPipeline, Hunyuan3DPaintConfig
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from torchvision_fix import apply_fix
|
| 14 |
+
apply_fix()
|
| 15 |
+
except ImportError:
|
| 16 |
+
print("Warning: torchvision_fix module not found, proceeding without compatibility fix")
|
| 17 |
+
except Exception as e:
|
| 18 |
+
print(f"Warning: Failed to apply torchvision fix: {e}")
|
| 19 |
+
|
| 20 |
+
# shape
|
| 21 |
+
model_path = 'tencent/Hunyuan3D-2.1'
|
| 22 |
+
pipeline_shapegen = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(model_path)
|
| 23 |
+
#
|
| 24 |
+
image_path = 'assets/demo.png'
|
| 25 |
+
image = Image.open(image_path).convert("RGBA")
|
| 26 |
+
if image.mode == 'RGB':
|
| 27 |
+
rembg = BackgroundRemover()
|
| 28 |
+
image = rembg(image)
|
| 29 |
+
|
| 30 |
+
mesh = pipeline_shapegen(image=image)[0]
|
| 31 |
+
mesh.export('demo.glb')
|
| 32 |
+
|
| 33 |
+
# paint
|
| 34 |
+
max_num_view = 6 # can be 6 to 9
|
| 35 |
+
resolution = 512 # can be 768 or 512
|
| 36 |
+
conf = Hunyuan3DPaintConfig(max_num_view, resolution)
|
| 37 |
+
conf.realesrgan_ckpt_path = "hy3dpaint/ckpt/RealESRGAN_x4plus.pth"
|
| 38 |
+
conf.multiview_cfg_path = "hy3dpaint/cfgs/hunyuan-paint-pbr.yaml"
|
| 39 |
+
conf.custom_pipeline = "hy3dpaint/hunyuanpaintpbr"
|
| 40 |
+
paint_pipeline = Hunyuan3DPaintPipeline(conf)
|
| 41 |
+
|
| 42 |
+
output_mesh_path = 'demo_textured.glb'
|
| 43 |
+
output_mesh_path = paint_pipeline(
|
| 44 |
+
mesh_path = "demo.glb",
|
| 45 |
+
image_path = 'assets/demo.png',
|
| 46 |
+
output_mesh_path = output_mesh_path
|
| 47 |
+
)
|
hunyuan3d-dit-v2-1/config.yaml
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
target: hy3dshape.models.denoisers.hunyuandit.HunYuanDiTPlain
|
| 3 |
+
params:
|
| 4 |
+
input_size: &num_latents 4096
|
| 5 |
+
in_channels: 64
|
| 6 |
+
hidden_size: 2048
|
| 7 |
+
context_dim: 1024
|
| 8 |
+
depth: 21
|
| 9 |
+
num_heads: 16
|
| 10 |
+
qk_norm: true
|
| 11 |
+
text_len: 1370
|
| 12 |
+
with_decoupled_ca: false
|
| 13 |
+
use_attention_pooling: false
|
| 14 |
+
qk_norm_type: 'rms'
|
| 15 |
+
qkv_bias: false
|
| 16 |
+
use_pos_emb: false
|
| 17 |
+
num_moe_layers: 6
|
| 18 |
+
num_experts: 8
|
| 19 |
+
moe_top_k: 2
|
| 20 |
+
|
| 21 |
+
vae:
|
| 22 |
+
target: hy3dshape.models.autoencoders.ShapeVAE
|
| 23 |
+
params:
|
| 24 |
+
num_latents: *num_latents
|
| 25 |
+
embed_dim: 64
|
| 26 |
+
num_freqs: 8
|
| 27 |
+
include_pi: false
|
| 28 |
+
heads: 16
|
| 29 |
+
width: 1024
|
| 30 |
+
num_encoder_layers: 8
|
| 31 |
+
num_decoder_layers: 16
|
| 32 |
+
qkv_bias: false
|
| 33 |
+
qk_norm: true
|
| 34 |
+
scale_factor: 1.0039506158752403
|
| 35 |
+
geo_decoder_mlp_expand_ratio: 4
|
| 36 |
+
geo_decoder_downsample_ratio: 1
|
| 37 |
+
geo_decoder_ln_post: true
|
| 38 |
+
point_feats: 4
|
| 39 |
+
pc_size: 81920
|
| 40 |
+
pc_sharpedge_size: 0
|
| 41 |
+
|
| 42 |
+
conditioner:
|
| 43 |
+
target: hy3dshape.models.conditioner.SingleImageEncoder
|
| 44 |
+
params:
|
| 45 |
+
main_image_encoder:
|
| 46 |
+
type: DinoImageEncoder # dino large
|
| 47 |
+
kwargs:
|
| 48 |
+
config:
|
| 49 |
+
attention_probs_dropout_prob: 0.0
|
| 50 |
+
drop_path_rate: 0.0
|
| 51 |
+
hidden_act: gelu
|
| 52 |
+
hidden_dropout_prob: 0.0
|
| 53 |
+
hidden_size: 1024
|
| 54 |
+
image_size: 518
|
| 55 |
+
initializer_range: 0.02
|
| 56 |
+
layer_norm_eps: 1.e-6
|
| 57 |
+
layerscale_value: 1.0
|
| 58 |
+
mlp_ratio: 4
|
| 59 |
+
model_type: dinov2
|
| 60 |
+
num_attention_heads: 16
|
| 61 |
+
num_channels: 3
|
| 62 |
+
num_hidden_layers: 24
|
| 63 |
+
patch_size: 14
|
| 64 |
+
qkv_bias: true
|
| 65 |
+
torch_dtype: float32
|
| 66 |
+
use_swiglu_ffn: false
|
| 67 |
+
image_size: 518
|
| 68 |
+
use_cls_token: true
|
| 69 |
+
|
| 70 |
+
scheduler:
|
| 71 |
+
target: hy3dshape.schedulers.FlowMatchEulerDiscreteScheduler
|
| 72 |
+
params:
|
| 73 |
+
num_train_timesteps: 1000
|
| 74 |
+
|
| 75 |
+
image_processor:
|
| 76 |
+
target: hy3dshape.preprocessors.ImageProcessorV2
|
| 77 |
+
params:
|
| 78 |
+
size: 512
|
| 79 |
+
border_ratio: 0.15
|
| 80 |
+
|
| 81 |
+
pipeline:
|
| 82 |
+
target: hy3dshape.pipelines.Hunyuan3DDiTFlowMatchingPipeline
|
hunyuan3d-dit-v2-1/model.fp16.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6b519fc7242f78e9b5f47ea4d55668fe3d944a2d27332f4ca68d29a6ff603f5e
|
| 3 |
+
size 7366389768
|
hunyuan3d-paintpbr-v2-1/README.md
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: openrail++
|
| 3 |
+
tags:
|
| 4 |
+
- stable-diffusion
|
| 5 |
+
- text-to-image
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
# SD v2.1-base with Zero Terminal SNR (LAION Aesthetic 6+)
|
| 9 |
+
|
| 10 |
+
This model is used in [Diffusion Model with Perceptual Loss](https://arxiv.org/abs/2401.00110) paper as the MSE baseline.
|
| 11 |
+
|
| 12 |
+
This model is trained using zero terminal SNR schedule following [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) paper on LAION aesthetic 6+ data.
|
| 13 |
+
|
| 14 |
+
This model is finetuned from [stabilityai/stable-diffusion-2-1-base](https://huggingface.co/stabilityai/stable-diffusion-2-1-base).
|
| 15 |
+
|
| 16 |
+
This model is meant for research demonstration, not for production use.
|
| 17 |
+
|
| 18 |
+
## Usage
|
| 19 |
+
|
| 20 |
+
```python
|
| 21 |
+
from diffusers import StableDiffusionPipeline
|
| 22 |
+
prompt = "A young girl smiling"
|
| 23 |
+
pipe = StableDiffusionPipeline.from_pretrained("ByteDance/sd2.1-base-zsnr-laionaes6").to("cuda")
|
| 24 |
+
pipe(prompt, guidance_scale=7.5, guidance_rescale=0.7).images[0].save("out.jpg")
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
## Related Models
|
| 28 |
+
|
| 29 |
+
* [bytedance/sd2.1-base-zsnr-laionaes5](https://huggingface.co/ByteDance/sd2.1-base-zsnr-laionaes5)
|
| 30 |
+
* [bytedance/sd2.1-base-zsnr-laionaes6](https://huggingface.co/ByteDance/sd2.1-base-zsnr-laionaes6)
|
| 31 |
+
* [bytedance/sd2.1-base-zsnr-laionaes6-perceptual](https://huggingface.co/ByteDance/sd2.1-base-zsnr-laionaes6-perceptual)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
## Cite as
|
| 35 |
+
```
|
| 36 |
+
@misc{lin2024diffusion,
|
| 37 |
+
title={Diffusion Model with Perceptual Loss},
|
| 38 |
+
author={Shanchuan Lin and Xiao Yang},
|
| 39 |
+
year={2024},
|
| 40 |
+
eprint={2401.00110},
|
| 41 |
+
archivePrefix={arXiv},
|
| 42 |
+
primaryClass={cs.CV}
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
@misc{lin2023common,
|
| 46 |
+
title={Common Diffusion Noise Schedules and Sample Steps are Flawed},
|
| 47 |
+
author={Shanchuan Lin and Bingchen Liu and Jiashi Li and Xiao Yang},
|
| 48 |
+
year={2023},
|
| 49 |
+
eprint={2305.08891},
|
| 50 |
+
archivePrefix={arXiv},
|
| 51 |
+
primaryClass={cs.CV}
|
| 52 |
+
}
|
| 53 |
+
```
|
hunyuan3d-paintpbr-v2-1/feature_extractor/preprocessor_config.json
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"crop_size": 224,
|
| 3 |
+
"do_center_crop": true,
|
| 4 |
+
"do_convert_rgb": true,
|
| 5 |
+
"do_normalize": true,
|
| 6 |
+
"do_resize": true,
|
| 7 |
+
"feature_extractor_type": "CLIPFeatureExtractor",
|
| 8 |
+
"image_mean": [
|
| 9 |
+
0.48145466,
|
| 10 |
+
0.4578275,
|
| 11 |
+
0.40821073
|
| 12 |
+
],
|
| 13 |
+
"image_std": [
|
| 14 |
+
0.26862954,
|
| 15 |
+
0.26130258,
|
| 16 |
+
0.27577711
|
| 17 |
+
],
|
| 18 |
+
"resample": 3,
|
| 19 |
+
"size": 224
|
| 20 |
+
}
|
hunyuan3d-paintpbr-v2-1/image_encoder/config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "vision_encoder",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"CLIPVisionModelWithProjection"
|
| 5 |
+
],
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"dropout": 0.0,
|
| 8 |
+
"hidden_act": "gelu",
|
| 9 |
+
"hidden_size": 1280,
|
| 10 |
+
"image_size": 224,
|
| 11 |
+
"initializer_factor": 1.0,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": 5120,
|
| 14 |
+
"layer_norm_eps": 1e-05,
|
| 15 |
+
"model_type": "clip_vision_model",
|
| 16 |
+
"num_attention_heads": 16,
|
| 17 |
+
"num_channels": 3,
|
| 18 |
+
"num_hidden_layers": 32,
|
| 19 |
+
"patch_size": 14,
|
| 20 |
+
"projection_dim": 1024,
|
| 21 |
+
"torch_dtype": "float16",
|
| 22 |
+
"transformers_version": "4.36.0"
|
| 23 |
+
}
|
hunyuan3d-paintpbr-v2-1/image_encoder/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ae616c24393dd1854372b0639e5541666f7521cbe219669255e865cb7f89466a
|
| 3 |
+
size 1264217240
|
hunyuan3d-paintpbr-v2-1/model_index.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "HunyuanPaintPipeline",
|
| 3 |
+
"_diffusers_version": "0.24.0",
|
| 4 |
+
"feature_extractor": [
|
| 5 |
+
"transformers",
|
| 6 |
+
"CLIPImageProcessor"
|
| 7 |
+
],
|
| 8 |
+
"requires_safety_checker": false,
|
| 9 |
+
"safety_checker": [
|
| 10 |
+
null,
|
| 11 |
+
null
|
| 12 |
+
],
|
| 13 |
+
"scheduler": [
|
| 14 |
+
"diffusers",
|
| 15 |
+
"DDIMScheduler"
|
| 16 |
+
],
|
| 17 |
+
"text_encoder": [
|
| 18 |
+
"transformers",
|
| 19 |
+
"CLIPTextModel"
|
| 20 |
+
],
|
| 21 |
+
"tokenizer": [
|
| 22 |
+
"transformers",
|
| 23 |
+
"CLIPTokenizer"
|
| 24 |
+
],
|
| 25 |
+
"unet": [
|
| 26 |
+
"modules",
|
| 27 |
+
"UNet2p5DConditionModel"
|
| 28 |
+
],
|
| 29 |
+
"vae": [
|
| 30 |
+
"diffusers",
|
| 31 |
+
"AutoencoderKL"
|
| 32 |
+
],
|
| 33 |
+
"image_encoder": [
|
| 34 |
+
"transformers",
|
| 35 |
+
"CLIPVisionModelWithProjection"
|
| 36 |
+
]
|
| 37 |
+
}
|
hunyuan3d-paintpbr-v2-1/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "DDIMScheduler",
|
| 3 |
+
"_diffusers_version": "0.23.1",
|
| 4 |
+
"beta_end": 0.012,
|
| 5 |
+
"beta_schedule": "scaled_linear",
|
| 6 |
+
"beta_start": 0.00085,
|
| 7 |
+
"clip_sample": false,
|
| 8 |
+
"num_train_timesteps": 1000,
|
| 9 |
+
"prediction_type": "v_prediction",
|
| 10 |
+
"set_alpha_to_one": true,
|
| 11 |
+
"steps_offset": 1,
|
| 12 |
+
"trained_betas": null,
|
| 13 |
+
"timestep_spacing": "trailing",
|
| 14 |
+
"rescale_betas_zero_snr": true
|
| 15 |
+
}
|
hunyuan3d-paintpbr-v2-1/text_encoder/config.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "stabilityai/stable-diffusion-2",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"CLIPTextModel"
|
| 5 |
+
],
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"bos_token_id": 0,
|
| 8 |
+
"dropout": 0.0,
|
| 9 |
+
"eos_token_id": 2,
|
| 10 |
+
"hidden_act": "gelu",
|
| 11 |
+
"hidden_size": 1024,
|
| 12 |
+
"initializer_factor": 1.0,
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"intermediate_size": 4096,
|
| 15 |
+
"layer_norm_eps": 1e-05,
|
| 16 |
+
"max_position_embeddings": 77,
|
| 17 |
+
"model_type": "clip_text_model",
|
| 18 |
+
"num_attention_heads": 16,
|
| 19 |
+
"num_hidden_layers": 23,
|
| 20 |
+
"pad_token_id": 1,
|
| 21 |
+
"projection_dim": 512,
|
| 22 |
+
"torch_dtype": "float32",
|
| 23 |
+
"transformers_version": "4.25.0.dev0",
|
| 24 |
+
"vocab_size": 49408
|
| 25 |
+
}
|
hunyuan3d-paintpbr-v2-1/text_encoder/pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c3e254d7b61353497ea0be2c4013df4ea8f739ee88cffa0ba58cd085459ed565
|
| 3 |
+
size 1361671895
|
hunyuan3d-paintpbr-v2-1/tokenizer/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
hunyuan3d-paintpbr-v2-1/tokenizer/special_tokens_map.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<|startoftext|>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": true,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "<|endoftext|>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": true,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": "!",
|
| 17 |
+
"unk_token": {
|
| 18 |
+
"content": "<|endoftext|>",
|
| 19 |
+
"lstrip": false,
|
| 20 |
+
"normalized": true,
|
| 21 |
+
"rstrip": false,
|
| 22 |
+
"single_word": false
|
| 23 |
+
}
|
| 24 |
+
}
|
hunyuan3d-paintpbr-v2-1/tokenizer/tokenizer_config.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"bos_token": {
|
| 4 |
+
"__type": "AddedToken",
|
| 5 |
+
"content": "<|startoftext|>",
|
| 6 |
+
"lstrip": false,
|
| 7 |
+
"normalized": true,
|
| 8 |
+
"rstrip": false,
|
| 9 |
+
"single_word": false
|
| 10 |
+
},
|
| 11 |
+
"do_lower_case": true,
|
| 12 |
+
"eos_token": {
|
| 13 |
+
"__type": "AddedToken",
|
| 14 |
+
"content": "<|endoftext|>",
|
| 15 |
+
"lstrip": false,
|
| 16 |
+
"normalized": true,
|
| 17 |
+
"rstrip": false,
|
| 18 |
+
"single_word": false
|
| 19 |
+
},
|
| 20 |
+
"errors": "replace",
|
| 21 |
+
"model_max_length": 77,
|
| 22 |
+
"name_or_path": "stabilityai/stable-diffusion-2",
|
| 23 |
+
"pad_token": "<|endoftext|>",
|
| 24 |
+
"special_tokens_map_file": "./special_tokens_map.json",
|
| 25 |
+
"tokenizer_class": "CLIPTokenizer",
|
| 26 |
+
"unk_token": {
|
| 27 |
+
"__type": "AddedToken",
|
| 28 |
+
"content": "<|endoftext|>",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": true,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false
|
| 33 |
+
}
|
| 34 |
+
}
|
hunyuan3d-paintpbr-v2-1/tokenizer/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
hunyuan3d-paintpbr-v2-1/unet/attn_processor.py
ADDED
|
@@ -0,0 +1,839 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 2 |
+
# except for the third-party components listed below.
|
| 3 |
+
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 4 |
+
# in the repsective licenses of these third-party components.
|
| 5 |
+
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 6 |
+
# components and must ensure that the usage of the third party components adheres to
|
| 7 |
+
# all relevant laws and regulations.
|
| 8 |
+
|
| 9 |
+
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 10 |
+
# their software and algorithms, including trained model weights, parameters (including
|
| 11 |
+
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 12 |
+
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 13 |
+
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from typing import Optional, Dict, Tuple, Union, Literal, List, Callable
|
| 19 |
+
from einops import rearrange
|
| 20 |
+
from diffusers.utils import deprecate
|
| 21 |
+
from diffusers.models.attention_processor import Attention, AttnProcessor
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class AttnUtils:
|
| 25 |
+
"""
|
| 26 |
+
Shared utility functions for attention processing.
|
| 27 |
+
|
| 28 |
+
This class provides common operations used across different attention processors
|
| 29 |
+
to eliminate code duplication and improve maintainability.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
@staticmethod
|
| 33 |
+
def check_pytorch_compatibility():
|
| 34 |
+
"""
|
| 35 |
+
Check PyTorch compatibility for scaled_dot_product_attention.
|
| 36 |
+
|
| 37 |
+
Raises:
|
| 38 |
+
ImportError: If PyTorch version doesn't support scaled_dot_product_attention
|
| 39 |
+
"""
|
| 40 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 41 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
| 42 |
+
|
| 43 |
+
@staticmethod
|
| 44 |
+
def handle_deprecation_warning(args, kwargs):
|
| 45 |
+
"""
|
| 46 |
+
Handle deprecation warning for the 'scale' argument.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
args: Positional arguments passed to attention processor
|
| 50 |
+
kwargs: Keyword arguments passed to attention processor
|
| 51 |
+
"""
|
| 52 |
+
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
| 53 |
+
deprecation_message = (
|
| 54 |
+
"The `scale` argument is deprecated and will be ignored."
|
| 55 |
+
"Please remove it, as passing it will raise an error in the future."
|
| 56 |
+
"`scale` should directly be passed while calling the underlying pipeline component"
|
| 57 |
+
"i.e., via `cross_attention_kwargs`."
|
| 58 |
+
)
|
| 59 |
+
deprecate("scale", "1.0.0", deprecation_message)
|
| 60 |
+
|
| 61 |
+
@staticmethod
|
| 62 |
+
def prepare_hidden_states(
|
| 63 |
+
hidden_states, attn, temb, spatial_norm_attr="spatial_norm", group_norm_attr="group_norm"
|
| 64 |
+
):
|
| 65 |
+
"""
|
| 66 |
+
Common preprocessing of hidden states for attention computation.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
hidden_states: Input hidden states tensor
|
| 70 |
+
attn: Attention module instance
|
| 71 |
+
temb: Optional temporal embedding tensor
|
| 72 |
+
spatial_norm_attr: Attribute name for spatial normalization
|
| 73 |
+
group_norm_attr: Attribute name for group normalization
|
| 74 |
+
|
| 75 |
+
Returns:
|
| 76 |
+
Tuple of (processed_hidden_states, residual, input_ndim, shape_info)
|
| 77 |
+
"""
|
| 78 |
+
residual = hidden_states
|
| 79 |
+
|
| 80 |
+
spatial_norm = getattr(attn, spatial_norm_attr, None)
|
| 81 |
+
if spatial_norm is not None:
|
| 82 |
+
hidden_states = spatial_norm(hidden_states, temb)
|
| 83 |
+
|
| 84 |
+
input_ndim = hidden_states.ndim
|
| 85 |
+
|
| 86 |
+
if input_ndim == 4:
|
| 87 |
+
batch_size, channel, height, width = hidden_states.shape
|
| 88 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
| 89 |
+
else:
|
| 90 |
+
batch_size, channel, height, width = None, None, None, None
|
| 91 |
+
|
| 92 |
+
group_norm = getattr(attn, group_norm_attr, None)
|
| 93 |
+
if group_norm is not None:
|
| 94 |
+
hidden_states = group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
| 95 |
+
|
| 96 |
+
return hidden_states, residual, input_ndim, (batch_size, channel, height, width)
|
| 97 |
+
|
| 98 |
+
@staticmethod
|
| 99 |
+
def prepare_attention_mask(attention_mask, attn, sequence_length, batch_size):
|
| 100 |
+
"""
|
| 101 |
+
Prepare attention mask for scaled_dot_product_attention.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
attention_mask: Input attention mask tensor or None
|
| 105 |
+
attn: Attention module instance
|
| 106 |
+
sequence_length: Length of the sequence
|
| 107 |
+
batch_size: Batch size
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
Prepared attention mask tensor reshaped for multi-head attention
|
| 111 |
+
"""
|
| 112 |
+
if attention_mask is not None:
|
| 113 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
| 114 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
| 115 |
+
return attention_mask
|
| 116 |
+
|
| 117 |
+
@staticmethod
|
| 118 |
+
def reshape_qkv_for_attention(tensor, batch_size, attn_heads, head_dim):
|
| 119 |
+
"""
|
| 120 |
+
Reshape Q/K/V tensors for multi-head attention computation.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
tensor: Input tensor to reshape
|
| 124 |
+
batch_size: Batch size
|
| 125 |
+
attn_heads: Number of attention heads
|
| 126 |
+
head_dim: Dimension per attention head
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
Reshaped tensor with shape [batch_size, attn_heads, seq_len, head_dim]
|
| 130 |
+
"""
|
| 131 |
+
return tensor.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
|
| 132 |
+
|
| 133 |
+
@staticmethod
|
| 134 |
+
def apply_norms(query, key, norm_q, norm_k):
|
| 135 |
+
"""
|
| 136 |
+
Apply Q/K normalization layers if available.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
query: Query tensor
|
| 140 |
+
key: Key tensor
|
| 141 |
+
norm_q: Query normalization layer (optional)
|
| 142 |
+
norm_k: Key normalization layer (optional)
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
Tuple of (normalized_query, normalized_key)
|
| 146 |
+
"""
|
| 147 |
+
if norm_q is not None:
|
| 148 |
+
query = norm_q(query)
|
| 149 |
+
if norm_k is not None:
|
| 150 |
+
key = norm_k(key)
|
| 151 |
+
return query, key
|
| 152 |
+
|
| 153 |
+
@staticmethod
|
| 154 |
+
def finalize_output(hidden_states, input_ndim, shape_info, attn, residual, to_out):
|
| 155 |
+
"""
|
| 156 |
+
Common output processing including projection, dropout, reshaping, and residual connection.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
hidden_states: Processed hidden states from attention
|
| 160 |
+
input_ndim: Original input tensor dimensions
|
| 161 |
+
shape_info: Tuple containing original shape information
|
| 162 |
+
attn: Attention module instance
|
| 163 |
+
residual: Residual connection tensor
|
| 164 |
+
to_out: Output projection layers [linear, dropout]
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
Final output tensor after all processing steps
|
| 168 |
+
"""
|
| 169 |
+
batch_size, channel, height, width = shape_info
|
| 170 |
+
|
| 171 |
+
# Apply output projection and dropout
|
| 172 |
+
hidden_states = to_out[0](hidden_states)
|
| 173 |
+
hidden_states = to_out[1](hidden_states)
|
| 174 |
+
|
| 175 |
+
# Reshape back if needed
|
| 176 |
+
if input_ndim == 4:
|
| 177 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
| 178 |
+
|
| 179 |
+
# Apply residual connection
|
| 180 |
+
if attn.residual_connection:
|
| 181 |
+
hidden_states = hidden_states + residual
|
| 182 |
+
|
| 183 |
+
# Apply rescaling
|
| 184 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
| 185 |
+
return hidden_states
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# Base class for attention processors (eliminating initialization duplication)
|
| 189 |
+
class BaseAttnProcessor(nn.Module):
|
| 190 |
+
"""
|
| 191 |
+
Base class for attention processors with common initialization.
|
| 192 |
+
|
| 193 |
+
This base class provides shared parameter initialization and module registration
|
| 194 |
+
functionality to reduce code duplication across different attention processor types.
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
def __init__(
|
| 198 |
+
self,
|
| 199 |
+
query_dim: int,
|
| 200 |
+
pbr_setting: List[str] = ["albedo", "mr"],
|
| 201 |
+
cross_attention_dim: Optional[int] = None,
|
| 202 |
+
heads: int = 8,
|
| 203 |
+
kv_heads: Optional[int] = None,
|
| 204 |
+
dim_head: int = 64,
|
| 205 |
+
dropout: float = 0.0,
|
| 206 |
+
bias: bool = False,
|
| 207 |
+
upcast_attention: bool = False,
|
| 208 |
+
upcast_softmax: bool = False,
|
| 209 |
+
cross_attention_norm: Optional[str] = None,
|
| 210 |
+
cross_attention_norm_num_groups: int = 32,
|
| 211 |
+
qk_norm: Optional[str] = None,
|
| 212 |
+
added_kv_proj_dim: Optional[int] = None,
|
| 213 |
+
added_proj_bias: Optional[bool] = True,
|
| 214 |
+
norm_num_groups: Optional[int] = None,
|
| 215 |
+
spatial_norm_dim: Optional[int] = None,
|
| 216 |
+
out_bias: bool = True,
|
| 217 |
+
scale_qk: bool = True,
|
| 218 |
+
only_cross_attention: bool = False,
|
| 219 |
+
eps: float = 1e-5,
|
| 220 |
+
rescale_output_factor: float = 1.0,
|
| 221 |
+
residual_connection: bool = False,
|
| 222 |
+
_from_deprecated_attn_block: bool = False,
|
| 223 |
+
processor: Optional["AttnProcessor"] = None,
|
| 224 |
+
out_dim: int = None,
|
| 225 |
+
out_context_dim: int = None,
|
| 226 |
+
context_pre_only=None,
|
| 227 |
+
pre_only=False,
|
| 228 |
+
elementwise_affine: bool = True,
|
| 229 |
+
is_causal: bool = False,
|
| 230 |
+
**kwargs,
|
| 231 |
+
):
|
| 232 |
+
"""
|
| 233 |
+
Initialize base attention processor with common parameters.
|
| 234 |
+
|
| 235 |
+
Args:
|
| 236 |
+
query_dim: Dimension of query features
|
| 237 |
+
pbr_setting: List of PBR material types to process (e.g., ["albedo", "mr"])
|
| 238 |
+
cross_attention_dim: Dimension of cross-attention features (optional)
|
| 239 |
+
heads: Number of attention heads
|
| 240 |
+
kv_heads: Number of key-value heads for grouped query attention (optional)
|
| 241 |
+
dim_head: Dimension per attention head
|
| 242 |
+
dropout: Dropout rate
|
| 243 |
+
bias: Whether to use bias in linear projections
|
| 244 |
+
upcast_attention: Whether to upcast attention computation to float32
|
| 245 |
+
upcast_softmax: Whether to upcast softmax computation to float32
|
| 246 |
+
cross_attention_norm: Type of cross-attention normalization (optional)
|
| 247 |
+
cross_attention_norm_num_groups: Number of groups for cross-attention norm
|
| 248 |
+
qk_norm: Type of query-key normalization (optional)
|
| 249 |
+
added_kv_proj_dim: Dimension for additional key-value projections (optional)
|
| 250 |
+
added_proj_bias: Whether to use bias in additional projections
|
| 251 |
+
norm_num_groups: Number of groups for normalization (optional)
|
| 252 |
+
spatial_norm_dim: Dimension for spatial normalization (optional)
|
| 253 |
+
out_bias: Whether to use bias in output projection
|
| 254 |
+
scale_qk: Whether to scale query-key products
|
| 255 |
+
only_cross_attention: Whether to only perform cross-attention
|
| 256 |
+
eps: Small epsilon value for numerical stability
|
| 257 |
+
rescale_output_factor: Factor to rescale output values
|
| 258 |
+
residual_connection: Whether to use residual connections
|
| 259 |
+
_from_deprecated_attn_block: Flag for deprecated attention blocks
|
| 260 |
+
processor: Optional attention processor instance
|
| 261 |
+
out_dim: Output dimension (optional)
|
| 262 |
+
out_context_dim: Output context dimension (optional)
|
| 263 |
+
context_pre_only: Whether to only process context in pre-processing
|
| 264 |
+
pre_only: Whether to only perform pre-processing
|
| 265 |
+
elementwise_affine: Whether to use element-wise affine transformations
|
| 266 |
+
is_causal: Whether to use causal attention masking
|
| 267 |
+
**kwargs: Additional keyword arguments
|
| 268 |
+
"""
|
| 269 |
+
super().__init__()
|
| 270 |
+
AttnUtils.check_pytorch_compatibility()
|
| 271 |
+
|
| 272 |
+
# Store common attributes
|
| 273 |
+
self.pbr_setting = pbr_setting
|
| 274 |
+
self.n_pbr_tokens = len(self.pbr_setting)
|
| 275 |
+
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
| 276 |
+
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
|
| 277 |
+
self.query_dim = query_dim
|
| 278 |
+
self.use_bias = bias
|
| 279 |
+
self.is_cross_attention = cross_attention_dim is not None
|
| 280 |
+
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
| 281 |
+
self.upcast_attention = upcast_attention
|
| 282 |
+
self.upcast_softmax = upcast_softmax
|
| 283 |
+
self.rescale_output_factor = rescale_output_factor
|
| 284 |
+
self.residual_connection = residual_connection
|
| 285 |
+
self.dropout = dropout
|
| 286 |
+
self.fused_projections = False
|
| 287 |
+
self.out_dim = out_dim if out_dim is not None else query_dim
|
| 288 |
+
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
|
| 289 |
+
self.context_pre_only = context_pre_only
|
| 290 |
+
self.pre_only = pre_only
|
| 291 |
+
self.is_causal = is_causal
|
| 292 |
+
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
| 293 |
+
self.scale_qk = scale_qk
|
| 294 |
+
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
| 295 |
+
self.heads = out_dim // dim_head if out_dim is not None else heads
|
| 296 |
+
self.sliceable_head_dim = heads
|
| 297 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
| 298 |
+
self.only_cross_attention = only_cross_attention
|
| 299 |
+
self.added_proj_bias = added_proj_bias
|
| 300 |
+
|
| 301 |
+
# Validation
|
| 302 |
+
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
| 303 |
+
raise ValueError(
|
| 304 |
+
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None."
|
| 305 |
+
"Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
def register_pbr_modules(self, module_types: List[str], **kwargs):
|
| 309 |
+
"""
|
| 310 |
+
Generic PBR module registration to eliminate code repetition.
|
| 311 |
+
|
| 312 |
+
Dynamically registers PyTorch modules for different PBR material types
|
| 313 |
+
based on the specified module types and PBR settings.
|
| 314 |
+
|
| 315 |
+
Args:
|
| 316 |
+
module_types: List of module types to register ("qkv", "v_only", "out", "add_kv")
|
| 317 |
+
**kwargs: Additional arguments for module configuration
|
| 318 |
+
"""
|
| 319 |
+
for pbr_token in self.pbr_setting:
|
| 320 |
+
if pbr_token == "albedo":
|
| 321 |
+
continue
|
| 322 |
+
|
| 323 |
+
for module_type in module_types:
|
| 324 |
+
if module_type == "qkv":
|
| 325 |
+
self.register_module(
|
| 326 |
+
f"to_q_{pbr_token}", nn.Linear(self.query_dim, self.inner_dim, bias=self.use_bias)
|
| 327 |
+
)
|
| 328 |
+
self.register_module(
|
| 329 |
+
f"to_k_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
|
| 330 |
+
)
|
| 331 |
+
self.register_module(
|
| 332 |
+
f"to_v_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
|
| 333 |
+
)
|
| 334 |
+
elif module_type == "v_only":
|
| 335 |
+
self.register_module(
|
| 336 |
+
f"to_v_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
|
| 337 |
+
)
|
| 338 |
+
elif module_type == "out":
|
| 339 |
+
if not self.pre_only:
|
| 340 |
+
self.register_module(
|
| 341 |
+
f"to_out_{pbr_token}",
|
| 342 |
+
nn.ModuleList(
|
| 343 |
+
[
|
| 344 |
+
nn.Linear(self.inner_dim, self.out_dim, bias=kwargs.get("out_bias", True)),
|
| 345 |
+
nn.Dropout(self.dropout),
|
| 346 |
+
]
|
| 347 |
+
),
|
| 348 |
+
)
|
| 349 |
+
else:
|
| 350 |
+
self.register_module(f"to_out_{pbr_token}", None)
|
| 351 |
+
elif module_type == "add_kv":
|
| 352 |
+
if self.added_kv_proj_dim is not None:
|
| 353 |
+
self.register_module(
|
| 354 |
+
f"add_k_proj_{pbr_token}",
|
| 355 |
+
nn.Linear(self.added_kv_proj_dim, self.inner_kv_dim, bias=self.added_proj_bias),
|
| 356 |
+
)
|
| 357 |
+
self.register_module(
|
| 358 |
+
f"add_v_proj_{pbr_token}",
|
| 359 |
+
nn.Linear(self.added_kv_proj_dim, self.inner_kv_dim, bias=self.added_proj_bias),
|
| 360 |
+
)
|
| 361 |
+
else:
|
| 362 |
+
self.register_module(f"add_k_proj_{pbr_token}", None)
|
| 363 |
+
self.register_module(f"add_v_proj_{pbr_token}", None)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# Rotary Position Embedding utilities (specialized for PoseRoPE)
|
| 367 |
+
class RotaryEmbedding:
|
| 368 |
+
"""
|
| 369 |
+
Rotary position embedding utilities for 3D spatial attention.
|
| 370 |
+
|
| 371 |
+
Provides functions to compute and apply rotary position embeddings (RoPE)
|
| 372 |
+
for 1D, 3D spatial coordinates used in 3D-aware attention mechanisms.
|
| 373 |
+
"""
|
| 374 |
+
|
| 375 |
+
@staticmethod
|
| 376 |
+
def get_1d_rotary_pos_embed(dim: int, pos: torch.Tensor, theta: float = 10000.0, linear_factor=1.0, ntk_factor=1.0):
|
| 377 |
+
"""
|
| 378 |
+
Compute 1D rotary position embeddings.
|
| 379 |
+
|
| 380 |
+
Args:
|
| 381 |
+
dim: Embedding dimension (must be even)
|
| 382 |
+
pos: Position tensor
|
| 383 |
+
theta: Base frequency for rotary embeddings
|
| 384 |
+
linear_factor: Linear scaling factor
|
| 385 |
+
ntk_factor: NTK (Neural Tangent Kernel) scaling factor
|
| 386 |
+
|
| 387 |
+
Returns:
|
| 388 |
+
Tuple of (cos_embeddings, sin_embeddings)
|
| 389 |
+
"""
|
| 390 |
+
assert dim % 2 == 0
|
| 391 |
+
theta = theta * ntk_factor
|
| 392 |
+
freqs = (
|
| 393 |
+
1.0
|
| 394 |
+
/ (theta ** (torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device)[: (dim // 2)] / dim))
|
| 395 |
+
/ linear_factor
|
| 396 |
+
)
|
| 397 |
+
freqs = torch.outer(pos, freqs)
|
| 398 |
+
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()
|
| 399 |
+
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()
|
| 400 |
+
return freqs_cos, freqs_sin
|
| 401 |
+
|
| 402 |
+
@staticmethod
|
| 403 |
+
def get_3d_rotary_pos_embed(position, embed_dim, voxel_resolution, theta: int = 10000):
|
| 404 |
+
"""
|
| 405 |
+
Compute 3D rotary position embeddings for spatial coordinates.
|
| 406 |
+
|
| 407 |
+
Args:
|
| 408 |
+
position: 3D position tensor with shape [..., 3]
|
| 409 |
+
embed_dim: Embedding dimension
|
| 410 |
+
voxel_resolution: Resolution of the voxel grid
|
| 411 |
+
theta: Base frequency for rotary embeddings
|
| 412 |
+
|
| 413 |
+
Returns:
|
| 414 |
+
Tuple of (cos_embeddings, sin_embeddings) for 3D positions
|
| 415 |
+
"""
|
| 416 |
+
assert position.shape[-1] == 3
|
| 417 |
+
dim_xy = embed_dim // 8 * 3
|
| 418 |
+
dim_z = embed_dim // 8 * 2
|
| 419 |
+
|
| 420 |
+
grid = torch.arange(voxel_resolution, dtype=torch.float32, device=position.device)
|
| 421 |
+
freqs_xy = RotaryEmbedding.get_1d_rotary_pos_embed(dim_xy, grid, theta=theta)
|
| 422 |
+
freqs_z = RotaryEmbedding.get_1d_rotary_pos_embed(dim_z, grid, theta=theta)
|
| 423 |
+
|
| 424 |
+
xy_cos, xy_sin = freqs_xy
|
| 425 |
+
z_cos, z_sin = freqs_z
|
| 426 |
+
|
| 427 |
+
embed_flattn = position.view(-1, position.shape[-1])
|
| 428 |
+
x_cos = xy_cos[embed_flattn[:, 0], :]
|
| 429 |
+
x_sin = xy_sin[embed_flattn[:, 0], :]
|
| 430 |
+
y_cos = xy_cos[embed_flattn[:, 1], :]
|
| 431 |
+
y_sin = xy_sin[embed_flattn[:, 1], :]
|
| 432 |
+
z_cos = z_cos[embed_flattn[:, 2], :]
|
| 433 |
+
z_sin = z_sin[embed_flattn[:, 2], :]
|
| 434 |
+
|
| 435 |
+
cos = torch.cat((x_cos, y_cos, z_cos), dim=-1)
|
| 436 |
+
sin = torch.cat((x_sin, y_sin, z_sin), dim=-1)
|
| 437 |
+
|
| 438 |
+
cos = cos.view(*position.shape[:-1], embed_dim)
|
| 439 |
+
sin = sin.view(*position.shape[:-1], embed_dim)
|
| 440 |
+
return cos, sin
|
| 441 |
+
|
| 442 |
+
@staticmethod
|
| 443 |
+
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]):
|
| 444 |
+
"""
|
| 445 |
+
Apply rotary position embeddings to input tensor.
|
| 446 |
+
|
| 447 |
+
Args:
|
| 448 |
+
x: Input tensor to apply rotary embeddings to
|
| 449 |
+
freqs_cis: Tuple of (cos_embeddings, sin_embeddings) or single tensor
|
| 450 |
+
|
| 451 |
+
Returns:
|
| 452 |
+
Tensor with rotary position embeddings applied
|
| 453 |
+
"""
|
| 454 |
+
cos, sin = freqs_cis
|
| 455 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
| 456 |
+
cos = cos.unsqueeze(1)
|
| 457 |
+
sin = sin.unsqueeze(1)
|
| 458 |
+
|
| 459 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
| 460 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
| 461 |
+
|
| 462 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
| 463 |
+
return out
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
# Core attention processing logic (eliminating major duplication)
|
| 467 |
+
class AttnCore:
|
| 468 |
+
"""
|
| 469 |
+
Core attention processing logic shared across processors.
|
| 470 |
+
|
| 471 |
+
This class provides the fundamental attention computation pipeline
|
| 472 |
+
that can be reused across different attention processor implementations.
|
| 473 |
+
"""
|
| 474 |
+
|
| 475 |
+
@staticmethod
|
| 476 |
+
def process_attention_base(
|
| 477 |
+
attn: Attention,
|
| 478 |
+
hidden_states: torch.Tensor,
|
| 479 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 480 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 481 |
+
temb: Optional[torch.Tensor] = None,
|
| 482 |
+
get_qkv_fn: Callable = None,
|
| 483 |
+
apply_rope_fn: Optional[Callable] = None,
|
| 484 |
+
**kwargs,
|
| 485 |
+
):
|
| 486 |
+
"""
|
| 487 |
+
Generic attention processing core shared across different processors.
|
| 488 |
+
|
| 489 |
+
This function implements the common attention computation pipeline including:
|
| 490 |
+
1. Hidden state preprocessing
|
| 491 |
+
2. Attention mask preparation
|
| 492 |
+
3. Q/K/V computation via provided function
|
| 493 |
+
4. Tensor reshaping for multi-head attention
|
| 494 |
+
5. Optional normalization and RoPE application
|
| 495 |
+
6. Scaled dot-product attention computation
|
| 496 |
+
|
| 497 |
+
Args:
|
| 498 |
+
attn: Attention module instance
|
| 499 |
+
hidden_states: Input hidden states tensor
|
| 500 |
+
encoder_hidden_states: Optional encoder hidden states for cross-attention
|
| 501 |
+
attention_mask: Optional attention mask tensor
|
| 502 |
+
temb: Optional temporal embedding tensor
|
| 503 |
+
get_qkv_fn: Function to compute Q, K, V tensors
|
| 504 |
+
apply_rope_fn: Optional function to apply rotary position embeddings
|
| 505 |
+
**kwargs: Additional keyword arguments passed to subfunctions
|
| 506 |
+
|
| 507 |
+
Returns:
|
| 508 |
+
Tuple containing (attention_output, residual, input_ndim, shape_info,
|
| 509 |
+
batch_size, num_heads, head_dim)
|
| 510 |
+
"""
|
| 511 |
+
# Prepare hidden states
|
| 512 |
+
hidden_states, residual, input_ndim, shape_info = AttnUtils.prepare_hidden_states(hidden_states, attn, temb)
|
| 513 |
+
|
| 514 |
+
batch_size, sequence_length, _ = (
|
| 515 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
# Prepare attention mask
|
| 519 |
+
attention_mask = AttnUtils.prepare_attention_mask(attention_mask, attn, sequence_length, batch_size)
|
| 520 |
+
|
| 521 |
+
# Get Q, K, V
|
| 522 |
+
if encoder_hidden_states is None:
|
| 523 |
+
encoder_hidden_states = hidden_states
|
| 524 |
+
elif attn.norm_cross:
|
| 525 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
| 526 |
+
|
| 527 |
+
query, key, value = get_qkv_fn(attn, hidden_states, encoder_hidden_states, **kwargs)
|
| 528 |
+
|
| 529 |
+
# Reshape for attention
|
| 530 |
+
inner_dim = key.shape[-1]
|
| 531 |
+
head_dim = inner_dim // attn.heads
|
| 532 |
+
|
| 533 |
+
query = AttnUtils.reshape_qkv_for_attention(query, batch_size, attn.heads, head_dim)
|
| 534 |
+
key = AttnUtils.reshape_qkv_for_attention(key, batch_size, attn.heads, head_dim)
|
| 535 |
+
value = AttnUtils.reshape_qkv_for_attention(value, batch_size, attn.heads, value.shape[-1] // attn.heads)
|
| 536 |
+
|
| 537 |
+
# Apply normalization
|
| 538 |
+
query, key = AttnUtils.apply_norms(query, key, getattr(attn, "norm_q", None), getattr(attn, "norm_k", None))
|
| 539 |
+
|
| 540 |
+
# Apply RoPE if provided
|
| 541 |
+
if apply_rope_fn is not None:
|
| 542 |
+
query, key = apply_rope_fn(query, key, head_dim, **kwargs)
|
| 543 |
+
|
| 544 |
+
# Compute attention
|
| 545 |
+
hidden_states = F.scaled_dot_product_attention(
|
| 546 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
return hidden_states, residual, input_ndim, shape_info, batch_size, attn.heads, head_dim
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
# Specific processor implementations (minimal unique code)
|
| 553 |
+
class PoseRoPEAttnProcessor2_0:
|
| 554 |
+
"""
|
| 555 |
+
Attention processor with Rotary Position Encoding (RoPE) for 3D spatial awareness.
|
| 556 |
+
|
| 557 |
+
This processor extends standard attention with 3D rotary position embeddings
|
| 558 |
+
to provide spatial awareness for 3D scene understanding tasks.
|
| 559 |
+
"""
|
| 560 |
+
|
| 561 |
+
def __init__(self):
|
| 562 |
+
"""Initialize the RoPE attention processor."""
|
| 563 |
+
AttnUtils.check_pytorch_compatibility()
|
| 564 |
+
|
| 565 |
+
def __call__(
|
| 566 |
+
self,
|
| 567 |
+
attn: Attention,
|
| 568 |
+
hidden_states: torch.Tensor,
|
| 569 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 570 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 571 |
+
position_indices: Dict = None,
|
| 572 |
+
temb: Optional[torch.Tensor] = None,
|
| 573 |
+
n_pbrs=1,
|
| 574 |
+
*args,
|
| 575 |
+
**kwargs,
|
| 576 |
+
) -> torch.Tensor:
|
| 577 |
+
"""
|
| 578 |
+
Apply RoPE-enhanced attention computation.
|
| 579 |
+
|
| 580 |
+
Args:
|
| 581 |
+
attn: Attention module instance
|
| 582 |
+
hidden_states: Input hidden states tensor
|
| 583 |
+
encoder_hidden_states: Optional encoder hidden states for cross-attention
|
| 584 |
+
attention_mask: Optional attention mask tensor
|
| 585 |
+
position_indices: Dictionary containing 3D position information for RoPE
|
| 586 |
+
temb: Optional temporal embedding tensor
|
| 587 |
+
n_pbrs: Number of PBR material types
|
| 588 |
+
*args: Additional positional arguments
|
| 589 |
+
**kwargs: Additional keyword arguments
|
| 590 |
+
|
| 591 |
+
Returns:
|
| 592 |
+
Attention output tensor with applied rotary position encodings
|
| 593 |
+
"""
|
| 594 |
+
AttnUtils.handle_deprecation_warning(args, kwargs)
|
| 595 |
+
|
| 596 |
+
def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
|
| 597 |
+
return attn.to_q(hidden_states), attn.to_k(encoder_hidden_states), attn.to_v(encoder_hidden_states)
|
| 598 |
+
|
| 599 |
+
def apply_rope(query, key, head_dim, **kwargs):
|
| 600 |
+
if position_indices is not None:
|
| 601 |
+
if head_dim in position_indices:
|
| 602 |
+
image_rotary_emb = position_indices[head_dim]
|
| 603 |
+
else:
|
| 604 |
+
image_rotary_emb = RotaryEmbedding.get_3d_rotary_pos_embed(
|
| 605 |
+
rearrange(
|
| 606 |
+
position_indices["voxel_indices"].unsqueeze(1).repeat(1, n_pbrs, 1, 1),
|
| 607 |
+
"b n_pbrs l c -> (b n_pbrs) l c",
|
| 608 |
+
),
|
| 609 |
+
head_dim,
|
| 610 |
+
voxel_resolution=position_indices["voxel_resolution"],
|
| 611 |
+
)
|
| 612 |
+
position_indices[head_dim] = image_rotary_emb
|
| 613 |
+
|
| 614 |
+
query = RotaryEmbedding.apply_rotary_emb(query, image_rotary_emb)
|
| 615 |
+
key = RotaryEmbedding.apply_rotary_emb(key, image_rotary_emb)
|
| 616 |
+
return query, key
|
| 617 |
+
|
| 618 |
+
# Core attention processing
|
| 619 |
+
hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
|
| 620 |
+
attn,
|
| 621 |
+
hidden_states,
|
| 622 |
+
encoder_hidden_states,
|
| 623 |
+
attention_mask,
|
| 624 |
+
temb,
|
| 625 |
+
get_qkv_fn=get_qkv,
|
| 626 |
+
apply_rope_fn=apply_rope,
|
| 627 |
+
position_indices=position_indices,
|
| 628 |
+
n_pbrs=n_pbrs,
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
# Finalize output
|
| 632 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
|
| 633 |
+
hidden_states = hidden_states.to(hidden_states.dtype)
|
| 634 |
+
|
| 635 |
+
return AttnUtils.finalize_output(hidden_states, input_ndim, shape_info, attn, residual, attn.to_out)
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
class SelfAttnProcessor2_0(BaseAttnProcessor):
|
| 639 |
+
"""
|
| 640 |
+
Self-attention processor with PBR (Physically Based Rendering) material support.
|
| 641 |
+
|
| 642 |
+
This processor handles multiple PBR material types (e.g., albedo, metallic-roughness)
|
| 643 |
+
with separate attention computation paths for each material type.
|
| 644 |
+
"""
|
| 645 |
+
|
| 646 |
+
def __init__(self, **kwargs):
|
| 647 |
+
"""
|
| 648 |
+
Initialize self-attention processor with PBR support.
|
| 649 |
+
|
| 650 |
+
Args:
|
| 651 |
+
**kwargs: Arguments passed to BaseAttnProcessor initialization
|
| 652 |
+
"""
|
| 653 |
+
super().__init__(**kwargs)
|
| 654 |
+
self.register_pbr_modules(["qkv", "out", "add_kv"], **kwargs)
|
| 655 |
+
|
| 656 |
+
def process_single(
|
| 657 |
+
self,
|
| 658 |
+
attn: Attention,
|
| 659 |
+
hidden_states: torch.Tensor,
|
| 660 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 661 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 662 |
+
temb: Optional[torch.Tensor] = None,
|
| 663 |
+
token: Literal["albedo", "mr"] = "albedo",
|
| 664 |
+
multiple_devices=False,
|
| 665 |
+
*args,
|
| 666 |
+
**kwargs,
|
| 667 |
+
):
|
| 668 |
+
"""
|
| 669 |
+
Process attention for a single PBR material type.
|
| 670 |
+
|
| 671 |
+
Args:
|
| 672 |
+
attn: Attention module instance
|
| 673 |
+
hidden_states: Input hidden states tensor
|
| 674 |
+
encoder_hidden_states: Optional encoder hidden states for cross-attention
|
| 675 |
+
attention_mask: Optional attention mask tensor
|
| 676 |
+
temb: Optional temporal embedding tensor
|
| 677 |
+
token: PBR material type to process ("albedo", "mr", etc.)
|
| 678 |
+
multiple_devices: Whether to use multiple GPU devices
|
| 679 |
+
*args: Additional positional arguments
|
| 680 |
+
**kwargs: Additional keyword arguments
|
| 681 |
+
|
| 682 |
+
Returns:
|
| 683 |
+
Processed attention output for the specified PBR material type
|
| 684 |
+
"""
|
| 685 |
+
target = attn if token == "albedo" else attn.processor
|
| 686 |
+
token_suffix = "" if token == "albedo" else "_" + token
|
| 687 |
+
|
| 688 |
+
# Device management (if needed)
|
| 689 |
+
if multiple_devices:
|
| 690 |
+
device = torch.device("cuda:0") if token == "albedo" else torch.device("cuda:1")
|
| 691 |
+
for attr in [f"to_q{token_suffix}", f"to_k{token_suffix}", f"to_v{token_suffix}", f"to_out{token_suffix}"]:
|
| 692 |
+
getattr(target, attr).to(device)
|
| 693 |
+
|
| 694 |
+
def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
|
| 695 |
+
return (
|
| 696 |
+
getattr(target, f"to_q{token_suffix}")(hidden_states),
|
| 697 |
+
getattr(target, f"to_k{token_suffix}")(encoder_hidden_states),
|
| 698 |
+
getattr(target, f"to_v{token_suffix}")(encoder_hidden_states),
|
| 699 |
+
)
|
| 700 |
+
|
| 701 |
+
# Core processing using shared logic
|
| 702 |
+
hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
|
| 703 |
+
attn, hidden_states, encoder_hidden_states, attention_mask, temb, get_qkv_fn=get_qkv
|
| 704 |
+
)
|
| 705 |
+
|
| 706 |
+
# Finalize
|
| 707 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
|
| 708 |
+
hidden_states = hidden_states.to(hidden_states.dtype)
|
| 709 |
+
|
| 710 |
+
return AttnUtils.finalize_output(
|
| 711 |
+
hidden_states, input_ndim, shape_info, attn, residual, getattr(target, f"to_out{token_suffix}")
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
def __call__(
|
| 715 |
+
self,
|
| 716 |
+
attn: Attention,
|
| 717 |
+
hidden_states: torch.Tensor,
|
| 718 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 719 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 720 |
+
temb: Optional[torch.Tensor] = None,
|
| 721 |
+
*args,
|
| 722 |
+
**kwargs,
|
| 723 |
+
) -> torch.Tensor:
|
| 724 |
+
"""
|
| 725 |
+
Apply self-attention with PBR material processing.
|
| 726 |
+
|
| 727 |
+
Processes multiple PBR material types sequentially, applying attention
|
| 728 |
+
computation for each material type separately and combining results.
|
| 729 |
+
|
| 730 |
+
Args:
|
| 731 |
+
attn: Attention module instance
|
| 732 |
+
hidden_states: Input hidden states tensor with PBR dimension
|
| 733 |
+
encoder_hidden_states: Optional encoder hidden states for cross-attention
|
| 734 |
+
attention_mask: Optional attention mask tensor
|
| 735 |
+
temb: Optional temporal embedding tensor
|
| 736 |
+
*args: Additional positional arguments
|
| 737 |
+
**kwargs: Additional keyword arguments
|
| 738 |
+
|
| 739 |
+
Returns:
|
| 740 |
+
Combined attention output for all PBR material types
|
| 741 |
+
"""
|
| 742 |
+
AttnUtils.handle_deprecation_warning(args, kwargs)
|
| 743 |
+
|
| 744 |
+
B = hidden_states.size(0)
|
| 745 |
+
pbr_hidden_states = torch.split(hidden_states, 1, dim=1)
|
| 746 |
+
|
| 747 |
+
# Process each PBR setting
|
| 748 |
+
results = []
|
| 749 |
+
for token, pbr_hs in zip(self.pbr_setting, pbr_hidden_states):
|
| 750 |
+
processed_hs = rearrange(pbr_hs, "b n_pbrs n l c -> (b n_pbrs n) l c").to("cuda:0")
|
| 751 |
+
result = self.process_single(attn, processed_hs, None, attention_mask, temb, token, False)
|
| 752 |
+
results.append(result)
|
| 753 |
+
|
| 754 |
+
outputs = [rearrange(result, "(b n_pbrs n) l c -> b n_pbrs n l c", b=B, n_pbrs=1) for result in results]
|
| 755 |
+
return torch.cat(outputs, dim=1)
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
class RefAttnProcessor2_0(BaseAttnProcessor):
|
| 759 |
+
"""
|
| 760 |
+
Reference attention processor with shared value computation across PBR materials.
|
| 761 |
+
|
| 762 |
+
This processor computes query and key once, but uses separate value projections
|
| 763 |
+
for different PBR material types, enabling efficient multi-material processing.
|
| 764 |
+
"""
|
| 765 |
+
|
| 766 |
+
def __init__(self, **kwargs):
|
| 767 |
+
"""
|
| 768 |
+
Initialize reference attention processor.
|
| 769 |
+
|
| 770 |
+
Args:
|
| 771 |
+
**kwargs: Arguments passed to BaseAttnProcessor initialization
|
| 772 |
+
"""
|
| 773 |
+
super().__init__(**kwargs)
|
| 774 |
+
self.pbr_settings = self.pbr_setting # Alias for compatibility
|
| 775 |
+
self.register_pbr_modules(["v_only", "out"], **kwargs)
|
| 776 |
+
|
| 777 |
+
def __call__(
|
| 778 |
+
self,
|
| 779 |
+
attn: Attention,
|
| 780 |
+
hidden_states: torch.Tensor,
|
| 781 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 782 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 783 |
+
temb: Optional[torch.Tensor] = None,
|
| 784 |
+
*args,
|
| 785 |
+
**kwargs,
|
| 786 |
+
) -> torch.Tensor:
|
| 787 |
+
"""
|
| 788 |
+
Apply reference attention with shared Q/K and separate V projections.
|
| 789 |
+
|
| 790 |
+
This method computes query and key tensors once and reuses them across
|
| 791 |
+
all PBR material types, while using separate value projections for each
|
| 792 |
+
material type to maintain material-specific information.
|
| 793 |
+
|
| 794 |
+
Args:
|
| 795 |
+
attn: Attention module instance
|
| 796 |
+
hidden_states: Input hidden states tensor
|
| 797 |
+
encoder_hidden_states: Optional encoder hidden states for cross-attention
|
| 798 |
+
attention_mask: Optional attention mask tensor
|
| 799 |
+
temb: Optional temporal embedding tensor
|
| 800 |
+
*args: Additional positional arguments
|
| 801 |
+
**kwargs: Additional keyword arguments
|
| 802 |
+
|
| 803 |
+
Returns:
|
| 804 |
+
Stacked attention output for all PBR material types
|
| 805 |
+
"""
|
| 806 |
+
AttnUtils.handle_deprecation_warning(args, kwargs)
|
| 807 |
+
|
| 808 |
+
def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
|
| 809 |
+
query = attn.to_q(hidden_states)
|
| 810 |
+
key = attn.to_k(encoder_hidden_states)
|
| 811 |
+
|
| 812 |
+
# Concatenate values from all PBR settings
|
| 813 |
+
value_list = [attn.to_v(encoder_hidden_states)]
|
| 814 |
+
for token in ["_" + token for token in self.pbr_settings if token != "albedo"]:
|
| 815 |
+
value_list.append(getattr(attn.processor, f"to_v{token}")(encoder_hidden_states))
|
| 816 |
+
value = torch.cat(value_list, dim=-1)
|
| 817 |
+
|
| 818 |
+
return query, key, value
|
| 819 |
+
|
| 820 |
+
# Core processing
|
| 821 |
+
hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
|
| 822 |
+
attn, hidden_states, encoder_hidden_states, attention_mask, temb, get_qkv_fn=get_qkv
|
| 823 |
+
)
|
| 824 |
+
|
| 825 |
+
# Split and process each PBR setting output
|
| 826 |
+
hidden_states_list = torch.split(hidden_states, head_dim, dim=-1)
|
| 827 |
+
output_hidden_states_list = []
|
| 828 |
+
|
| 829 |
+
for i, hs in enumerate(hidden_states_list):
|
| 830 |
+
hs = hs.transpose(1, 2).reshape(batch_size, -1, heads * head_dim).to(hs.dtype)
|
| 831 |
+
token_suffix = "_" + self.pbr_settings[i] if self.pbr_settings[i] != "albedo" else ""
|
| 832 |
+
target = attn if self.pbr_settings[i] == "albedo" else attn.processor
|
| 833 |
+
|
| 834 |
+
hs = AttnUtils.finalize_output(
|
| 835 |
+
hs, input_ndim, shape_info, attn, residual, getattr(target, f"to_out{token_suffix}")
|
| 836 |
+
)
|
| 837 |
+
output_hidden_states_list.append(hs)
|
| 838 |
+
|
| 839 |
+
return torch.stack(output_hidden_states_list, dim=1)
|
hunyuan3d-paintpbr-v2-1/unet/config.json
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "UNet2DConditionModel",
|
| 3 |
+
"_diffusers_version": "0.10.0.dev0",
|
| 4 |
+
"act_fn": "silu",
|
| 5 |
+
"attention_head_dim": [
|
| 6 |
+
5,
|
| 7 |
+
10,
|
| 8 |
+
20,
|
| 9 |
+
20
|
| 10 |
+
],
|
| 11 |
+
"block_out_channels": [
|
| 12 |
+
320,
|
| 13 |
+
640,
|
| 14 |
+
1280,
|
| 15 |
+
1280
|
| 16 |
+
],
|
| 17 |
+
"center_input_sample": false,
|
| 18 |
+
"cross_attention_dim": 1024,
|
| 19 |
+
"down_block_types": [
|
| 20 |
+
"CrossAttnDownBlock2D",
|
| 21 |
+
"CrossAttnDownBlock2D",
|
| 22 |
+
"CrossAttnDownBlock2D",
|
| 23 |
+
"DownBlock2D"
|
| 24 |
+
],
|
| 25 |
+
"downsample_padding": 1,
|
| 26 |
+
"dual_cross_attention": false,
|
| 27 |
+
"flip_sin_to_cos": true,
|
| 28 |
+
"freq_shift": 0,
|
| 29 |
+
"in_channels": 4,
|
| 30 |
+
"layers_per_block": 2,
|
| 31 |
+
"mid_block_scale_factor": 1,
|
| 32 |
+
"norm_eps": 1e-05,
|
| 33 |
+
"norm_num_groups": 32,
|
| 34 |
+
"num_class_embeds": null,
|
| 35 |
+
"only_cross_attention": false,
|
| 36 |
+
"out_channels": 4,
|
| 37 |
+
"sample_size": 64,
|
| 38 |
+
"up_block_types": [
|
| 39 |
+
"UpBlock2D",
|
| 40 |
+
"CrossAttnUpBlock2D",
|
| 41 |
+
"CrossAttnUpBlock2D",
|
| 42 |
+
"CrossAttnUpBlock2D"
|
| 43 |
+
],
|
| 44 |
+
"use_linear_projection": true
|
| 45 |
+
}
|
hunyuan3d-paintpbr-v2-1/unet/diffusion_pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:675a1b5cd0098b2002637c443946529c03c5cd54427f40245263350feb3dd5b8
|
| 3 |
+
size 3925293863
|
hunyuan3d-paintpbr-v2-1/unet/model.py
ADDED
|
@@ -0,0 +1,622 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 2 |
+
# except for the third-party components listed below.
|
| 3 |
+
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 4 |
+
# in the repsective licenses of these third-party components.
|
| 5 |
+
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 6 |
+
# components and must ensure that the usage of the third party components adheres to
|
| 7 |
+
# all relevant laws and regulations.
|
| 8 |
+
|
| 9 |
+
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 10 |
+
# their software and algorithms, including trained model weights, parameters (including
|
| 11 |
+
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 12 |
+
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 13 |
+
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
# import ipdb
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
import pytorch_lightning as pl
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
from torchvision.transforms import v2
|
| 25 |
+
from torchvision.utils import make_grid, save_image
|
| 26 |
+
from einops import rearrange
|
| 27 |
+
|
| 28 |
+
from diffusers import (
|
| 29 |
+
DiffusionPipeline,
|
| 30 |
+
EulerAncestralDiscreteScheduler,
|
| 31 |
+
DDPMScheduler,
|
| 32 |
+
UNet2DConditionModel,
|
| 33 |
+
ControlNetModel,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
from .modules import Dino_v2, UNet2p5DConditionModel
|
| 37 |
+
import math
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def extract_into_tensor(a, t, x_shape):
|
| 41 |
+
b, *_ = t.shape
|
| 42 |
+
out = a.gather(-1, t)
|
| 43 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class HunyuanPaint(pl.LightningModule):
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
stable_diffusion_config,
|
| 50 |
+
control_net_config=None,
|
| 51 |
+
num_view=6,
|
| 52 |
+
view_size=320,
|
| 53 |
+
drop_cond_prob=0.1,
|
| 54 |
+
with_normal_map=None,
|
| 55 |
+
with_position_map=None,
|
| 56 |
+
pbr_settings=["albedo", "mr"],
|
| 57 |
+
**kwargs,
|
| 58 |
+
):
|
| 59 |
+
"""Initializes the HunyuanPaint Lightning Module.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
stable_diffusion_config: Configuration for loading the Stable Diffusion pipeline
|
| 63 |
+
control_net_config: Configuration for ControlNet (optional)
|
| 64 |
+
num_view: Number of views to process
|
| 65 |
+
view_size: Size of input views (height/width)
|
| 66 |
+
drop_cond_prob: Probability of dropping conditioning input during training
|
| 67 |
+
with_normal_map: Flag indicating whether normal maps are used
|
| 68 |
+
with_position_map: Flag indicating whether position maps are used
|
| 69 |
+
pbr_settings: List of PBR materials to generate (e.g., albedo, metallic-roughness)
|
| 70 |
+
**kwargs: Additional keyword arguments
|
| 71 |
+
"""
|
| 72 |
+
super(HunyuanPaint, self).__init__()
|
| 73 |
+
|
| 74 |
+
self.num_view = num_view
|
| 75 |
+
self.view_size = view_size
|
| 76 |
+
self.drop_cond_prob = drop_cond_prob
|
| 77 |
+
self.pbr_settings = pbr_settings
|
| 78 |
+
|
| 79 |
+
# init modules
|
| 80 |
+
pipeline = DiffusionPipeline.from_pretrained(**stable_diffusion_config)
|
| 81 |
+
pipeline.set_pbr_settings(self.pbr_settings)
|
| 82 |
+
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
| 83 |
+
pipeline.scheduler.config, timestep_spacing="trailing"
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
self.with_normal_map = with_normal_map
|
| 87 |
+
self.with_position_map = with_position_map
|
| 88 |
+
|
| 89 |
+
self.pipeline = pipeline
|
| 90 |
+
|
| 91 |
+
self.pipeline.vae.use_slicing = True
|
| 92 |
+
|
| 93 |
+
train_sched = DDPMScheduler.from_config(self.pipeline.scheduler.config)
|
| 94 |
+
|
| 95 |
+
if isinstance(self.pipeline.unet, UNet2DConditionModel):
|
| 96 |
+
self.pipeline.unet = UNet2p5DConditionModel(
|
| 97 |
+
self.pipeline.unet, train_sched, self.pipeline.scheduler, self.pbr_settings
|
| 98 |
+
)
|
| 99 |
+
self.train_scheduler = train_sched # use ddpm scheduler during training
|
| 100 |
+
|
| 101 |
+
self.register_schedule()
|
| 102 |
+
|
| 103 |
+
pipeline.set_learned_parameters()
|
| 104 |
+
|
| 105 |
+
if control_net_config is not None:
|
| 106 |
+
pipeline.unet = pipeline.unet.bfloat16().requires_grad_(control_net_config.train_unet)
|
| 107 |
+
self.pipeline.add_controlnet(
|
| 108 |
+
ControlNetModel.from_pretrained(control_net_config.pretrained_model_name_or_path),
|
| 109 |
+
conditioning_scale=0.75,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
self.unet = pipeline.unet
|
| 113 |
+
|
| 114 |
+
self.pipeline.set_progress_bar_config(disable=True)
|
| 115 |
+
self.pipeline.vae = self.pipeline.vae.bfloat16()
|
| 116 |
+
self.pipeline.text_encoder = self.pipeline.text_encoder.bfloat16()
|
| 117 |
+
|
| 118 |
+
if self.unet.use_dino:
|
| 119 |
+
self.dino_v2 = Dino_v2("facebook/dinov2-giant")
|
| 120 |
+
self.dino_v2 = self.dino_v2.bfloat16()
|
| 121 |
+
|
| 122 |
+
self.validation_step_outputs = []
|
| 123 |
+
|
| 124 |
+
def register_schedule(self):
|
| 125 |
+
|
| 126 |
+
self.num_timesteps = self.train_scheduler.config.num_train_timesteps
|
| 127 |
+
|
| 128 |
+
betas = self.train_scheduler.betas.detach().cpu()
|
| 129 |
+
|
| 130 |
+
alphas = 1.0 - betas
|
| 131 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
| 132 |
+
alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]], 0)
|
| 133 |
+
|
| 134 |
+
self.register_buffer("betas", betas.float())
|
| 135 |
+
self.register_buffer("alphas_cumprod", alphas_cumprod.float())
|
| 136 |
+
self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev.float())
|
| 137 |
+
|
| 138 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
| 139 |
+
self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod).float())
|
| 140 |
+
self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1 - alphas_cumprod).float())
|
| 141 |
+
|
| 142 |
+
self.register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod).float())
|
| 143 |
+
self.register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1).float())
|
| 144 |
+
|
| 145 |
+
def on_fit_start(self):
|
| 146 |
+
device = torch.device(f"cuda:{self.local_rank}")
|
| 147 |
+
self.pipeline.to(device)
|
| 148 |
+
if self.global_rank == 0:
|
| 149 |
+
os.makedirs(os.path.join(self.logdir, "images_val"), exist_ok=True)
|
| 150 |
+
|
| 151 |
+
def prepare_batch_data(self, batch):
|
| 152 |
+
"""Preprocesses a batch of input data for training/inference.
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
batch: Raw input batch dictionary
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
tuple: Contains:
|
| 159 |
+
- cond_imgs: Primary conditioning images (B, 1, C, H, W)
|
| 160 |
+
- cond_imgs_another: Secondary conditioning images (B, 1, C, H, W)
|
| 161 |
+
- target_imgs: Dictionary of target PBR images resized and clamped
|
| 162 |
+
- images_normal: Preprocessed normal maps (if available)
|
| 163 |
+
- images_position: Preprocessed position maps (if available)
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
images_cond = batch["images_cond"].to(self.device) # (B, M, C, H, W), where M is the number of reference images
|
| 167 |
+
cond_imgs, cond_imgs_another = images_cond[:, 0:1, ...], images_cond[:, 1:2, ...]
|
| 168 |
+
|
| 169 |
+
cond_size = self.view_size
|
| 170 |
+
cond_imgs = v2.functional.resize(cond_imgs, cond_size, interpolation=3, antialias=True).clamp(0, 1)
|
| 171 |
+
cond_imgs_another = v2.functional.resize(cond_imgs_another, cond_size, interpolation=3, antialias=True).clamp(
|
| 172 |
+
0, 1
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
target_imgs = {}
|
| 176 |
+
for pbr_token in self.pbr_settings:
|
| 177 |
+
target_imgs[pbr_token] = batch[f"images_{pbr_token}"].to(self.device)
|
| 178 |
+
target_imgs[pbr_token] = v2.functional.resize(
|
| 179 |
+
target_imgs[pbr_token], self.view_size, interpolation=3, antialias=True
|
| 180 |
+
).clamp(0, 1)
|
| 181 |
+
|
| 182 |
+
images_normal = None
|
| 183 |
+
if "images_normal" in batch:
|
| 184 |
+
images_normal = batch["images_normal"] # (B, N, C, H, W)
|
| 185 |
+
images_normal = v2.functional.resize(images_normal, self.view_size, interpolation=3, antialias=True).clamp(
|
| 186 |
+
0, 1
|
| 187 |
+
)
|
| 188 |
+
images_normal = [images_normal]
|
| 189 |
+
|
| 190 |
+
images_position = None
|
| 191 |
+
if "images_position" in batch:
|
| 192 |
+
images_position = batch["images_position"] # (B, N, C, H, W)
|
| 193 |
+
images_position = v2.functional.resize(
|
| 194 |
+
images_position, self.view_size, interpolation=3, antialias=True
|
| 195 |
+
).clamp(0, 1)
|
| 196 |
+
images_position = [images_position]
|
| 197 |
+
|
| 198 |
+
return cond_imgs, cond_imgs_another, target_imgs, images_normal, images_position
|
| 199 |
+
|
| 200 |
+
@torch.no_grad()
|
| 201 |
+
def forward_text_encoder(self, prompts):
|
| 202 |
+
device = next(self.pipeline.vae.parameters()).device
|
| 203 |
+
text_embeds = self.pipeline.encode_prompt(prompts, device, 1, False)[0]
|
| 204 |
+
return text_embeds
|
| 205 |
+
|
| 206 |
+
@torch.no_grad()
|
| 207 |
+
def encode_images(self, images):
|
| 208 |
+
"""Encodes input images into latent representations using the VAE.
|
| 209 |
+
|
| 210 |
+
Handles both standard input (B, N, C, H, W) and PBR input (B, N_pbrs, N, C, H, W)
|
| 211 |
+
Maintains original batch structure in output latents.
|
| 212 |
+
|
| 213 |
+
Args:
|
| 214 |
+
images: Input images tensor
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
torch.Tensor: Latent representations with original batch dimensions preserved
|
| 218 |
+
"""
|
| 219 |
+
|
| 220 |
+
B = images.shape[0]
|
| 221 |
+
image_ndims = images.ndim
|
| 222 |
+
if image_ndims != 5:
|
| 223 |
+
N_pbrs, N = images.shape[1:3]
|
| 224 |
+
images = (
|
| 225 |
+
rearrange(images, "b n c h w -> (b n) c h w")
|
| 226 |
+
if image_ndims == 5
|
| 227 |
+
else rearrange(images, "b n_pbrs n c h w -> (b n_pbrs n) c h w")
|
| 228 |
+
)
|
| 229 |
+
dtype = next(self.pipeline.vae.parameters()).dtype
|
| 230 |
+
|
| 231 |
+
images = (images - 0.5) * 2.0
|
| 232 |
+
posterior = self.pipeline.vae.encode(images.to(dtype)).latent_dist
|
| 233 |
+
latents = posterior.sample() * self.pipeline.vae.config.scaling_factor
|
| 234 |
+
|
| 235 |
+
latents = (
|
| 236 |
+
rearrange(latents, "(b n) c h w -> b n c h w", b=B)
|
| 237 |
+
if image_ndims == 5
|
| 238 |
+
else rearrange(latents, "(b n_pbrs n) c h w -> b n_pbrs n c h w", b=B, n_pbrs=N_pbrs)
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
return latents
|
| 242 |
+
|
| 243 |
+
def forward_unet(self, latents, t, **cached_condition):
|
| 244 |
+
"""Runs the UNet model to predict noise/latent residuals.
|
| 245 |
+
|
| 246 |
+
Args:
|
| 247 |
+
latents: Noisy latent representations (B, C, H, W)
|
| 248 |
+
t: Timestep tensor (B,)
|
| 249 |
+
**cached_condition: Dictionary of conditioning inputs (text embeds, reference images, etc)
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
torch.Tensor: UNet output (predicted noise or velocity)
|
| 253 |
+
"""
|
| 254 |
+
|
| 255 |
+
dtype = next(self.unet.parameters()).dtype
|
| 256 |
+
latents = latents.to(dtype)
|
| 257 |
+
shading_embeds = cached_condition["shading_embeds"]
|
| 258 |
+
pred_noise = self.pipeline.unet(latents, t, encoder_hidden_states=shading_embeds, **cached_condition)
|
| 259 |
+
return pred_noise[0]
|
| 260 |
+
|
| 261 |
+
def predict_start_from_z_and_v(self, x_t, t, v):
|
| 262 |
+
"""
|
| 263 |
+
Predicts clean image (x0) from noisy latents (x_t) and
|
| 264 |
+
velocity prediction (v) using the v-prediction formula.
|
| 265 |
+
|
| 266 |
+
Args:
|
| 267 |
+
x_t: Noisy latents at timestep t
|
| 268 |
+
t: Current timestep
|
| 269 |
+
v: Predicted velocity (v) from UNet
|
| 270 |
+
|
| 271 |
+
Returns:
|
| 272 |
+
torch.Tensor: Predicted clean image (x0)
|
| 273 |
+
"""
|
| 274 |
+
|
| 275 |
+
return (
|
| 276 |
+
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
|
| 277 |
+
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
def get_v(self, x, noise, t):
|
| 281 |
+
"""Computes the target velocity (v) for v-prediction training.
|
| 282 |
+
|
| 283 |
+
Args:
|
| 284 |
+
x: Clean latents (x0)
|
| 285 |
+
noise: Added noise
|
| 286 |
+
t: Current timestep
|
| 287 |
+
|
| 288 |
+
Returns:
|
| 289 |
+
torch.Tensor: Target velocity
|
| 290 |
+
"""
|
| 291 |
+
|
| 292 |
+
return (
|
| 293 |
+
extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
|
| 294 |
+
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
def training_step(self, batch, batch_idx):
|
| 298 |
+
"""Performs a single training step with both conditioning paths.
|
| 299 |
+
|
| 300 |
+
Implements:
|
| 301 |
+
1. Dual-conditioning path training (main ref + secondary ref)
|
| 302 |
+
2. Velocity-prediction with consistency loss
|
| 303 |
+
3. Conditional dropout for robust learning
|
| 304 |
+
4. PBR-specific losses (albedo/metallic-roughness)
|
| 305 |
+
|
| 306 |
+
Args:
|
| 307 |
+
batch: Input batch from dataloader
|
| 308 |
+
batch_idx: Index of current batch
|
| 309 |
+
|
| 310 |
+
Returns:
|
| 311 |
+
torch.Tensor: Combined loss value
|
| 312 |
+
"""
|
| 313 |
+
|
| 314 |
+
cond_imgs, cond_imgs_another, target_imgs, normal_imgs, position_imgs = self.prepare_batch_data(batch)
|
| 315 |
+
|
| 316 |
+
B, N_ref = cond_imgs.shape[:2]
|
| 317 |
+
_, N_gen, _, H, W = target_imgs["albedo"].shape
|
| 318 |
+
N_pbrs = len(self.pbr_settings)
|
| 319 |
+
t = torch.randint(0, self.num_timesteps, size=(B,)).long().to(self.device)
|
| 320 |
+
t = t.unsqueeze(-1).repeat(1, N_pbrs, N_gen)
|
| 321 |
+
t = rearrange(t, "b n_pbrs n -> (b n_pbrs n)")
|
| 322 |
+
|
| 323 |
+
all_target_pbrs = []
|
| 324 |
+
for pbr_token in self.pbr_settings:
|
| 325 |
+
all_target_pbrs.append(target_imgs[pbr_token])
|
| 326 |
+
all_target_pbrs = torch.stack(all_target_pbrs, dim=0).transpose(1, 0)
|
| 327 |
+
gen_latents = self.encode_images(all_target_pbrs) #! B, N_pbrs N C H W
|
| 328 |
+
ref_latents = self.encode_images(cond_imgs) #! B, M, C, H, W
|
| 329 |
+
ref_latents_another = self.encode_images(cond_imgs_another) #! B, M, C, H, W
|
| 330 |
+
|
| 331 |
+
all_shading_tokens = []
|
| 332 |
+
for token in self.pbr_settings:
|
| 333 |
+
if token in ["albedo", "mr"]:
|
| 334 |
+
all_shading_tokens.append(
|
| 335 |
+
getattr(self.unet, f"learned_text_clip_{token}").unsqueeze(dim=0).repeat(B, 1, 1)
|
| 336 |
+
)
|
| 337 |
+
shading_embeds = torch.stack(all_shading_tokens, dim=1)
|
| 338 |
+
|
| 339 |
+
if self.unet.use_dino:
|
| 340 |
+
dino_hidden_states = self.dino_v2(cond_imgs[:, :1, ...])
|
| 341 |
+
dino_hidden_states_another = self.dino_v2(cond_imgs_another[:, :1, ...])
|
| 342 |
+
|
| 343 |
+
gen_latents = rearrange(gen_latents, "b n_pbrs n c h w -> (b n_pbrs n) c h w")
|
| 344 |
+
noise = torch.randn_like(gen_latents).to(self.device)
|
| 345 |
+
latents_noisy = self.train_scheduler.add_noise(gen_latents, noise, t).to(self.device)
|
| 346 |
+
latents_noisy = rearrange(latents_noisy, "(b n_pbrs n) c h w -> b n_pbrs n c h w", b=B, n_pbrs=N_pbrs)
|
| 347 |
+
|
| 348 |
+
cached_condition = {}
|
| 349 |
+
|
| 350 |
+
if normal_imgs is not None:
|
| 351 |
+
normal_embeds = self.encode_images(normal_imgs[0])
|
| 352 |
+
cached_condition["embeds_normal"] = normal_embeds #! B, N, C, H, W
|
| 353 |
+
|
| 354 |
+
if position_imgs is not None:
|
| 355 |
+
position_embeds = self.encode_images(position_imgs[0])
|
| 356 |
+
cached_condition["embeds_position"] = position_embeds #! B, N, C, H, W
|
| 357 |
+
cached_condition["position_maps"] = position_imgs[0] #! B, N, C, H, W
|
| 358 |
+
|
| 359 |
+
for b in range(B):
|
| 360 |
+
prob = np.random.rand()
|
| 361 |
+
if prob < self.drop_cond_prob:
|
| 362 |
+
if "normal_imgs" in cached_condition:
|
| 363 |
+
cached_condition["embeds_normal"][b, ...] = torch.zeros_like(
|
| 364 |
+
cached_condition["embeds_normal"][b, ...]
|
| 365 |
+
)
|
| 366 |
+
if "position_imgs" in cached_condition:
|
| 367 |
+
cached_condition["embeds_position"][b, ...] = torch.zeros_like(
|
| 368 |
+
cached_condition["embeds_position"][b, ...]
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
prob = np.random.rand()
|
| 372 |
+
if prob < self.drop_cond_prob:
|
| 373 |
+
if "position_maps" in cached_condition:
|
| 374 |
+
cached_condition["position_maps"][b, ...] = torch.zeros_like(
|
| 375 |
+
cached_condition["position_maps"][b, ...]
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
prob = np.random.rand()
|
| 379 |
+
if prob < self.drop_cond_prob:
|
| 380 |
+
dino_hidden_states[b, ...] = torch.zeros_like(dino_hidden_states[b, ...])
|
| 381 |
+
prob = np.random.rand()
|
| 382 |
+
if prob < self.drop_cond_prob:
|
| 383 |
+
dino_hidden_states_another[b, ...] = torch.zeros_like(dino_hidden_states_another[b, ...])
|
| 384 |
+
|
| 385 |
+
# MVA & Ref Attention
|
| 386 |
+
prob = np.random.rand()
|
| 387 |
+
cached_condition["mva_scale"] = 1.0
|
| 388 |
+
cached_condition["ref_scale"] = 1.0
|
| 389 |
+
if prob < self.drop_cond_prob:
|
| 390 |
+
cached_condition["mva_scale"] = 0.0
|
| 391 |
+
cached_condition["ref_scale"] = 0.0
|
| 392 |
+
elif prob > 1.0 - self.drop_cond_prob:
|
| 393 |
+
prob = np.random.rand()
|
| 394 |
+
if prob < 0.5:
|
| 395 |
+
cached_condition["mva_scale"] = 0.0
|
| 396 |
+
else:
|
| 397 |
+
cached_condition["ref_scale"] = 0.0
|
| 398 |
+
else:
|
| 399 |
+
pass
|
| 400 |
+
|
| 401 |
+
if self.train_scheduler.config.prediction_type == "v_prediction":
|
| 402 |
+
|
| 403 |
+
cached_condition["shading_embeds"] = shading_embeds
|
| 404 |
+
cached_condition["ref_latents"] = ref_latents
|
| 405 |
+
cached_condition["dino_hidden_states"] = dino_hidden_states
|
| 406 |
+
v_pred = self.forward_unet(latents_noisy, t, **cached_condition)
|
| 407 |
+
v_pred_albedo, v_pred_mr = torch.split(
|
| 408 |
+
rearrange(
|
| 409 |
+
v_pred, "(b n_pbr n) c h w -> b n_pbr n c h w", n_pbr=len(self.pbr_settings), n=self.num_view
|
| 410 |
+
),
|
| 411 |
+
1,
|
| 412 |
+
dim=1,
|
| 413 |
+
)
|
| 414 |
+
v_target = self.get_v(gen_latents, noise, t)
|
| 415 |
+
v_target_albedo, v_target_mr = torch.split(
|
| 416 |
+
rearrange(
|
| 417 |
+
v_target, "(b n_pbr n) c h w -> b n_pbr n c h w", n_pbr=len(self.pbr_settings), n=self.num_view
|
| 418 |
+
),
|
| 419 |
+
1,
|
| 420 |
+
dim=1,
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
albedo_loss_1, _ = self.compute_loss(v_pred_albedo, v_target_albedo)
|
| 424 |
+
mr_loss_1, _ = self.compute_loss(v_pred_mr, v_target_mr)
|
| 425 |
+
|
| 426 |
+
cached_condition["ref_latents"] = ref_latents_another
|
| 427 |
+
cached_condition["dino_hidden_states"] = dino_hidden_states_another
|
| 428 |
+
v_pred_another = self.forward_unet(latents_noisy, t, **cached_condition)
|
| 429 |
+
v_pred_another_albedo, v_pred_another_mr = torch.split(
|
| 430 |
+
rearrange(
|
| 431 |
+
v_pred_another,
|
| 432 |
+
"(b n_pbr n) c h w -> b n_pbr n c h w",
|
| 433 |
+
n_pbr=len(self.pbr_settings),
|
| 434 |
+
n=self.num_view,
|
| 435 |
+
),
|
| 436 |
+
1,
|
| 437 |
+
dim=1,
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
albedo_loss_2, _ = self.compute_loss(v_pred_another_albedo, v_target_albedo)
|
| 441 |
+
mr_loss_2, _ = self.compute_loss(v_pred_another_mr, v_target_mr)
|
| 442 |
+
|
| 443 |
+
consistency_loss, _ = self.compute_loss(v_pred_another, v_pred)
|
| 444 |
+
|
| 445 |
+
albedo_loss = (albedo_loss_1 + albedo_loss_2) * 0.5
|
| 446 |
+
mr_loss = (mr_loss_1 + mr_loss_2) * 0.5
|
| 447 |
+
|
| 448 |
+
log_loss_dict = {}
|
| 449 |
+
log_loss_dict.update({f"train/albedo_loss": albedo_loss})
|
| 450 |
+
log_loss_dict.update({f"train/mr_loss": mr_loss})
|
| 451 |
+
log_loss_dict.update({f"train/cons_loss": consistency_loss})
|
| 452 |
+
|
| 453 |
+
loss_dict = log_loss_dict
|
| 454 |
+
|
| 455 |
+
elif self.train_scheduler.config.prediction_type == "epsilon":
|
| 456 |
+
e_pred = self.forward_unet(latents_noisy, t, **cached_condition)
|
| 457 |
+
loss, loss_dict = self.compute_loss(e_pred, noise)
|
| 458 |
+
else:
|
| 459 |
+
raise f"No {self.train_scheduler.config.prediction_type}"
|
| 460 |
+
|
| 461 |
+
# logging
|
| 462 |
+
self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
| 463 |
+
self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False)
|
| 464 |
+
lr = self.optimizers().param_groups[0]["lr"]
|
| 465 |
+
self.log("lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
|
| 466 |
+
|
| 467 |
+
return 0.85 * (albedo_loss + mr_loss) + 0.15 * consistency_loss
|
| 468 |
+
|
| 469 |
+
def compute_loss(self, noise_pred, noise_gt):
|
| 470 |
+
loss = F.mse_loss(noise_pred, noise_gt)
|
| 471 |
+
prefix = "train"
|
| 472 |
+
loss_dict = {}
|
| 473 |
+
loss_dict.update({f"{prefix}/loss": loss})
|
| 474 |
+
return loss, loss_dict
|
| 475 |
+
|
| 476 |
+
@torch.no_grad()
|
| 477 |
+
def validation_step(self, batch, batch_idx):
|
| 478 |
+
"""Performs validation on a single batch.
|
| 479 |
+
|
| 480 |
+
Generates predicted images using:
|
| 481 |
+
1. Reference conditioning images
|
| 482 |
+
2. Optional normal/position maps
|
| 483 |
+
3. Frozen DINO features (if enabled)
|
| 484 |
+
4. Text prompt conditioning
|
| 485 |
+
|
| 486 |
+
Compares predictions against ground truth targets and prepares visualization.
|
| 487 |
+
Stores results for epoch-level aggregation.
|
| 488 |
+
|
| 489 |
+
Args:
|
| 490 |
+
batch: Input batch from validation dataloader
|
| 491 |
+
batch_idx: Index of current batch
|
| 492 |
+
"""
|
| 493 |
+
# [Validation image generation and comparison logic...]
|
| 494 |
+
# Key steps:
|
| 495 |
+
# 1. Preprocess conditioning images to PIL format
|
| 496 |
+
# 2. Set up conditioning inputs (normal maps, position maps, DINO features)
|
| 497 |
+
# 3. Run pipeline inference with fixed prompt ("high quality")
|
| 498 |
+
# 4. Decode latent outputs to image space
|
| 499 |
+
# 5. Arrange predictions and ground truths for visualization
|
| 500 |
+
|
| 501 |
+
cond_imgs_tensor, _, target_imgs, normal_imgs, position_imgs = self.prepare_batch_data(batch)
|
| 502 |
+
resolution = self.view_size
|
| 503 |
+
image_pils = []
|
| 504 |
+
for i in range(cond_imgs_tensor.shape[0]):
|
| 505 |
+
image_pils.append([])
|
| 506 |
+
for j in range(cond_imgs_tensor.shape[1]):
|
| 507 |
+
image_pils[-1].append(v2.functional.to_pil_image(cond_imgs_tensor[i, j, ...]))
|
| 508 |
+
|
| 509 |
+
outputs, gts = [], []
|
| 510 |
+
for idx in range(len(image_pils)):
|
| 511 |
+
cond_imgs = image_pils[idx]
|
| 512 |
+
|
| 513 |
+
cached_condition = dict(num_in_batch=self.num_view, N_pbrs=len(self.pbr_settings))
|
| 514 |
+
if normal_imgs is not None:
|
| 515 |
+
cached_condition["images_normal"] = normal_imgs[0][idx, ...].unsqueeze(0)
|
| 516 |
+
if position_imgs is not None:
|
| 517 |
+
cached_condition["images_position"] = position_imgs[0][idx, ...].unsqueeze(0)
|
| 518 |
+
if self.pipeline.unet.use_dino:
|
| 519 |
+
dino_hidden_states = self.dino_v2([cond_imgs][0])
|
| 520 |
+
cached_condition["dino_hidden_states"] = dino_hidden_states
|
| 521 |
+
|
| 522 |
+
latent = self.pipeline(
|
| 523 |
+
cond_imgs,
|
| 524 |
+
prompt="high quality",
|
| 525 |
+
num_inference_steps=30,
|
| 526 |
+
output_type="latent",
|
| 527 |
+
height=resolution,
|
| 528 |
+
width=resolution,
|
| 529 |
+
**cached_condition,
|
| 530 |
+
).images
|
| 531 |
+
|
| 532 |
+
image = self.pipeline.vae.decode(latent / self.pipeline.vae.config.scaling_factor, return_dict=False)[
|
| 533 |
+
0
|
| 534 |
+
] # [-1, 1]
|
| 535 |
+
image = (image * 0.5 + 0.5).clamp(0, 1)
|
| 536 |
+
|
| 537 |
+
image = rearrange(
|
| 538 |
+
image, "(b n_pbr n) c h w -> b n_pbr n c h w", n_pbr=len(self.pbr_settings), n=self.num_view
|
| 539 |
+
)
|
| 540 |
+
image = torch.cat((torch.ones_like(image[:, :, :1, ...]) * 0.5, image), dim=2)
|
| 541 |
+
image = rearrange(image, "b n_pbr n c h w -> (b n_pbr n) c h w")
|
| 542 |
+
image = rearrange(
|
| 543 |
+
image,
|
| 544 |
+
"(b n_pbr n) c h w -> b c (n_pbr h) (n w)",
|
| 545 |
+
b=1,
|
| 546 |
+
n_pbr=len(self.pbr_settings),
|
| 547 |
+
n=self.num_view + 1,
|
| 548 |
+
)
|
| 549 |
+
outputs.append(image)
|
| 550 |
+
|
| 551 |
+
all_target_pbrs = []
|
| 552 |
+
for pbr_token in self.pbr_settings:
|
| 553 |
+
all_target_pbrs.append(target_imgs[pbr_token])
|
| 554 |
+
all_target_pbrs = torch.stack(all_target_pbrs, dim=0).transpose(1, 0)
|
| 555 |
+
all_target_pbrs = torch.cat(
|
| 556 |
+
(cond_imgs_tensor.unsqueeze(1).repeat(1, len(self.pbr_settings), 1, 1, 1, 1), all_target_pbrs), dim=2
|
| 557 |
+
)
|
| 558 |
+
all_target_pbrs = rearrange(all_target_pbrs, "b n_pbrs n c h w -> b c (n_pbrs h) (n w)")
|
| 559 |
+
gts = all_target_pbrs
|
| 560 |
+
outputs = torch.cat(outputs, dim=0).to(self.device)
|
| 561 |
+
images = torch.cat([gts, outputs], dim=-2)
|
| 562 |
+
self.validation_step_outputs.append(images)
|
| 563 |
+
|
| 564 |
+
@torch.no_grad()
|
| 565 |
+
def on_validation_epoch_end(self):
|
| 566 |
+
"""Aggregates validation results at epoch end.
|
| 567 |
+
|
| 568 |
+
Gathers outputs from all GPUs (if distributed training),
|
| 569 |
+
creates a unified visualization grid, and saves to disk.
|
| 570 |
+
Only rank 0 process performs saving.
|
| 571 |
+
"""
|
| 572 |
+
# [Result aggregation and visualization...]
|
| 573 |
+
# Key steps:
|
| 574 |
+
# 1. Gather validation outputs from all processes
|
| 575 |
+
# 2. Create image grid combining ground truths and predictions
|
| 576 |
+
# 3. Save visualization with step-numbered filename
|
| 577 |
+
# 4. Clear memory for next validation cycle
|
| 578 |
+
|
| 579 |
+
images = torch.cat(self.validation_step_outputs, dim=0)
|
| 580 |
+
all_images = self.all_gather(images)
|
| 581 |
+
all_images = rearrange(all_images, "r b c h w -> (r b) c h w")
|
| 582 |
+
|
| 583 |
+
if self.global_rank == 0:
|
| 584 |
+
grid = make_grid(all_images, nrow=8, normalize=True, value_range=(0, 1))
|
| 585 |
+
save_image(grid, os.path.join(self.logdir, "images_val", f"val_{self.global_step:07d}.png"))
|
| 586 |
+
|
| 587 |
+
self.validation_step_outputs.clear() # free memory
|
| 588 |
+
|
| 589 |
+
def configure_optimizers(self):
|
| 590 |
+
lr = self.learning_rate
|
| 591 |
+
optimizer = torch.optim.AdamW(self.unet.parameters(), lr=lr)
|
| 592 |
+
|
| 593 |
+
def lr_lambda(step):
|
| 594 |
+
warm_up_step = 1000
|
| 595 |
+
T_step = 9000
|
| 596 |
+
gamma = 0.9
|
| 597 |
+
min_lr = 0.1 if step >= warm_up_step else 0.0
|
| 598 |
+
max_lr = 1.0
|
| 599 |
+
normalized_step = step % (warm_up_step + T_step)
|
| 600 |
+
current_max_lr = max_lr * gamma ** (step // (warm_up_step + T_step))
|
| 601 |
+
if current_max_lr < min_lr:
|
| 602 |
+
current_max_lr = min_lr
|
| 603 |
+
if normalized_step < warm_up_step:
|
| 604 |
+
lr_step = min_lr + (normalized_step / warm_up_step) * (current_max_lr - min_lr)
|
| 605 |
+
else:
|
| 606 |
+
step_wc_wp = normalized_step - warm_up_step
|
| 607 |
+
ratio = step_wc_wp / T_step
|
| 608 |
+
lr_step = min_lr + 0.5 * (current_max_lr - min_lr) * (1 + math.cos(math.pi * ratio))
|
| 609 |
+
return lr_step
|
| 610 |
+
|
| 611 |
+
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
| 612 |
+
|
| 613 |
+
lr_scheduler_config = {
|
| 614 |
+
"scheduler": lr_scheduler,
|
| 615 |
+
"interval": "step",
|
| 616 |
+
"frequency": 1,
|
| 617 |
+
"monitor": "val_loss",
|
| 618 |
+
"strict": False,
|
| 619 |
+
"name": None,
|
| 620 |
+
}
|
| 621 |
+
|
| 622 |
+
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
|
hunyuan3d-paintpbr-v2-1/unet/modules.py
ADDED
|
@@ -0,0 +1,1102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 2 |
+
# except for the third-party components listed below.
|
| 3 |
+
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 4 |
+
# in the repsective licenses of these third-party components.
|
| 5 |
+
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 6 |
+
# components and must ensure that the usage of the third party components adheres to
|
| 7 |
+
# all relevant laws and regulations.
|
| 8 |
+
|
| 9 |
+
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 10 |
+
# their software and algorithms, including trained model weights, parameters (including
|
| 11 |
+
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 12 |
+
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 13 |
+
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import json
|
| 17 |
+
import copy
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
from einops import rearrange
|
| 22 |
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple, Literal
|
| 23 |
+
import diffusers
|
| 24 |
+
from diffusers.utils import deprecate
|
| 25 |
+
from diffusers import (
|
| 26 |
+
DDPMScheduler,
|
| 27 |
+
EulerAncestralDiscreteScheduler,
|
| 28 |
+
UNet2DConditionModel,
|
| 29 |
+
)
|
| 30 |
+
from diffusers.models import UNet2DConditionModel
|
| 31 |
+
from diffusers.models.attention_processor import Attention, AttnProcessor
|
| 32 |
+
from diffusers.models.transformers.transformer_2d import BasicTransformerBlock
|
| 33 |
+
from .attn_processor import SelfAttnProcessor2_0, RefAttnProcessor2_0, PoseRoPEAttnProcessor2_0
|
| 34 |
+
|
| 35 |
+
from transformers import AutoImageProcessor, AutoModel
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Dino_v2(nn.Module):
|
| 39 |
+
|
| 40 |
+
"""Wrapper for DINOv2 vision transformer (frozen weights).
|
| 41 |
+
|
| 42 |
+
Provides feature extraction for reference images.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
dino_v2_path: Custom path to DINOv2 model weights (uses default if None)
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def __init__(self, dino_v2_path):
|
| 50 |
+
super(Dino_v2, self).__init__()
|
| 51 |
+
self.dino_processor = AutoImageProcessor.from_pretrained(dino_v2_path)
|
| 52 |
+
self.dino_v2 = AutoModel.from_pretrained(dino_v2_path)
|
| 53 |
+
|
| 54 |
+
for param in self.parameters():
|
| 55 |
+
param.requires_grad = False
|
| 56 |
+
|
| 57 |
+
self.dino_v2.eval()
|
| 58 |
+
|
| 59 |
+
def forward(self, images):
|
| 60 |
+
|
| 61 |
+
"""Processes input images through DINOv2 ViT.
|
| 62 |
+
|
| 63 |
+
Handles both tensor input (B, N, C, H, W) and PIL image lists.
|
| 64 |
+
Extracts patch embeddings and flattens spatial dimensions.
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
torch.Tensor: Feature vectors [B, N*(num_patches), feature_dim]
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
if isinstance(images, torch.Tensor):
|
| 71 |
+
batch_size = images.shape[0]
|
| 72 |
+
dino_proceesed_images = self.dino_processor(
|
| 73 |
+
images=rearrange(images, "b n c h w -> (b n) c h w"), return_tensors="pt", do_rescale=False
|
| 74 |
+
).pixel_values
|
| 75 |
+
else:
|
| 76 |
+
batch_size = 1
|
| 77 |
+
dino_proceesed_images = self.dino_processor(images=images, return_tensors="pt").pixel_values
|
| 78 |
+
dino_proceesed_images = torch.stack(
|
| 79 |
+
[torch.from_numpy(np.array(image)) for image in dino_proceesed_images], dim=0
|
| 80 |
+
)
|
| 81 |
+
dino_param = next(self.dino_v2.parameters())
|
| 82 |
+
dino_proceesed_images = dino_proceesed_images.to(dino_param)
|
| 83 |
+
dino_hidden_states = self.dino_v2(dino_proceesed_images)[0]
|
| 84 |
+
dino_hidden_states = rearrange(dino_hidden_states.to(dino_param), "(b n) l c -> b (n l) c", b=batch_size)
|
| 85 |
+
|
| 86 |
+
return dino_hidden_states
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
|
| 90 |
+
# "feed_forward_chunk_size" can be used to save memory
|
| 91 |
+
|
| 92 |
+
"""Memory-efficient feedforward execution via chunking.
|
| 93 |
+
|
| 94 |
+
Divides input along specified dimension for sequential processing.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
ff: Feedforward module to apply
|
| 98 |
+
hidden_states: Input tensor
|
| 99 |
+
chunk_dim: Dimension to split
|
| 100 |
+
chunk_size: Size of each chunk
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
torch.Tensor: Reassembled output tensor
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
if hidden_states.shape[chunk_dim] % chunk_size != 0:
|
| 107 |
+
raise ValueError(
|
| 108 |
+
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]}"
|
| 109 |
+
f"has to be divisible by chunk size: {chunk_size}."
|
| 110 |
+
"Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
|
| 114 |
+
ff_output = torch.cat(
|
| 115 |
+
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
|
| 116 |
+
dim=chunk_dim,
|
| 117 |
+
)
|
| 118 |
+
return ff_output
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@torch.no_grad()
|
| 122 |
+
def compute_voxel_grid_mask(position, grid_resolution=8):
|
| 123 |
+
|
| 124 |
+
"""Generates view-to-view attention mask based on 3D position similarity.
|
| 125 |
+
|
| 126 |
+
Uses voxel grid downsampling to determine spatially adjacent regions.
|
| 127 |
+
Mask indicates where features should interact across different views.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
position: Position maps [B, N, 3, H, W] (normalized 0-1)
|
| 131 |
+
grid_resolution: Spatial reduction factor
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
torch.Tensor: Attention mask [B, N*grid_res**2, N*grid_res**2]
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
position = position.half()
|
| 138 |
+
B, N, _, H, W = position.shape
|
| 139 |
+
assert H % grid_resolution == 0 and W % grid_resolution == 0
|
| 140 |
+
|
| 141 |
+
valid_mask = (position != 1).all(dim=2, keepdim=True)
|
| 142 |
+
valid_mask = valid_mask.expand_as(position)
|
| 143 |
+
position[valid_mask == False] = 0
|
| 144 |
+
|
| 145 |
+
position = rearrange(
|
| 146 |
+
position,
|
| 147 |
+
"b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w",
|
| 148 |
+
num_h=grid_resolution,
|
| 149 |
+
num_w=grid_resolution,
|
| 150 |
+
)
|
| 151 |
+
valid_mask = rearrange(
|
| 152 |
+
valid_mask,
|
| 153 |
+
"b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w",
|
| 154 |
+
num_h=grid_resolution,
|
| 155 |
+
num_w=grid_resolution,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
grid_position = position.sum(dim=(-2, -1))
|
| 159 |
+
count_masked = valid_mask.sum(dim=(-2, -1))
|
| 160 |
+
|
| 161 |
+
grid_position = grid_position / count_masked.clamp(min=1)
|
| 162 |
+
grid_position[count_masked < 5] = 0
|
| 163 |
+
|
| 164 |
+
grid_position = grid_position.permute(0, 1, 4, 2, 3)
|
| 165 |
+
grid_position = rearrange(grid_position, "b n c h w -> b n (h w) c")
|
| 166 |
+
|
| 167 |
+
grid_position_expanded_1 = grid_position.unsqueeze(2).unsqueeze(4) # 形状变为 B, N, 1, L, 1, 3
|
| 168 |
+
grid_position_expanded_2 = grid_position.unsqueeze(1).unsqueeze(3) # 形状变为 B, 1, N, 1, L, 3
|
| 169 |
+
|
| 170 |
+
# 计算欧氏距离
|
| 171 |
+
distances = torch.norm(grid_position_expanded_1 - grid_position_expanded_2, dim=-1) # 形状为 B, N, N, L, L
|
| 172 |
+
|
| 173 |
+
weights = distances
|
| 174 |
+
grid_distance = 1.73 / grid_resolution
|
| 175 |
+
weights = weights < grid_distance
|
| 176 |
+
|
| 177 |
+
return weights
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def compute_multi_resolution_mask(position_maps, grid_resolutions=[32, 16, 8]):
|
| 181 |
+
|
| 182 |
+
"""Generates attention masks at multiple spatial resolutions.
|
| 183 |
+
|
| 184 |
+
Creates pyramid of position-based masks for hierarchical attention.
|
| 185 |
+
|
| 186 |
+
Args:
|
| 187 |
+
position_maps: Position maps [B, N, 3, H, W]
|
| 188 |
+
grid_resolutions: List of downsampling factors
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
dict: Resolution-specific masks keyed by flattened dimension size
|
| 192 |
+
"""
|
| 193 |
+
|
| 194 |
+
position_attn_mask = {}
|
| 195 |
+
with torch.no_grad():
|
| 196 |
+
for grid_resolution in grid_resolutions:
|
| 197 |
+
position_mask = compute_voxel_grid_mask(position_maps, grid_resolution)
|
| 198 |
+
position_mask = rearrange(position_mask, "b ni nj li lj -> b (ni li) (nj lj)")
|
| 199 |
+
position_attn_mask[position_mask.shape[1]] = position_mask
|
| 200 |
+
return position_attn_mask
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
@torch.no_grad()
|
| 204 |
+
def compute_discrete_voxel_indice(position, grid_resolution=8, voxel_resolution=128):
|
| 205 |
+
|
| 206 |
+
"""Quantizes position maps to discrete voxel indices.
|
| 207 |
+
|
| 208 |
+
Creates sparse 3D coordinate representations for efficient hashing.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
position: Position maps [B, N, 3, H, W]
|
| 212 |
+
grid_resolution: Spatial downsampling factor
|
| 213 |
+
voxel_resolution: Quantization resolution
|
| 214 |
+
|
| 215 |
+
Returns:
|
| 216 |
+
torch.Tensor: Voxel indices [B, N, grid_res, grid_res, 3]
|
| 217 |
+
"""
|
| 218 |
+
|
| 219 |
+
position = position.half()
|
| 220 |
+
B, N, _, H, W = position.shape
|
| 221 |
+
assert H % grid_resolution == 0 and W % grid_resolution == 0
|
| 222 |
+
|
| 223 |
+
valid_mask = (position != 1).all(dim=2, keepdim=True)
|
| 224 |
+
valid_mask = valid_mask.expand_as(position)
|
| 225 |
+
position[valid_mask == False] = 0
|
| 226 |
+
|
| 227 |
+
position = rearrange(
|
| 228 |
+
position,
|
| 229 |
+
"b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w",
|
| 230 |
+
num_h=grid_resolution,
|
| 231 |
+
num_w=grid_resolution,
|
| 232 |
+
)
|
| 233 |
+
valid_mask = rearrange(
|
| 234 |
+
valid_mask,
|
| 235 |
+
"b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w",
|
| 236 |
+
num_h=grid_resolution,
|
| 237 |
+
num_w=grid_resolution,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
grid_position = position.sum(dim=(-2, -1))
|
| 241 |
+
count_masked = valid_mask.sum(dim=(-2, -1))
|
| 242 |
+
|
| 243 |
+
grid_position = grid_position / count_masked.clamp(min=1)
|
| 244 |
+
voxel_mask_thres = (H // grid_resolution) * (W // grid_resolution) // (4 * 4)
|
| 245 |
+
grid_position[count_masked < voxel_mask_thres] = 0
|
| 246 |
+
|
| 247 |
+
grid_position = grid_position.permute(0, 1, 4, 2, 3).clamp(0, 1) # B N C H W
|
| 248 |
+
voxel_indices = grid_position * (voxel_resolution - 1)
|
| 249 |
+
voxel_indices = torch.round(voxel_indices).long()
|
| 250 |
+
return voxel_indices
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def calc_multires_voxel_idxs(position_maps, grid_resolutions=[64, 32, 16, 8], voxel_resolutions=[512, 256, 128, 64]):
|
| 254 |
+
|
| 255 |
+
"""Generates multi-resolution voxel indices for position encoding.
|
| 256 |
+
|
| 257 |
+
Creates pyramid of quantized position representations.
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
position_maps: Input position maps
|
| 261 |
+
grid_resolutions: Spatial resolution levels
|
| 262 |
+
voxel_resolutions: Quantization levels
|
| 263 |
+
|
| 264 |
+
Returns:
|
| 265 |
+
dict: Voxel indices keyed by flattened dimension size, with resolution metadata
|
| 266 |
+
"""
|
| 267 |
+
|
| 268 |
+
voxel_indices = {}
|
| 269 |
+
with torch.no_grad():
|
| 270 |
+
for grid_resolution, voxel_resolution in zip(grid_resolutions, voxel_resolutions):
|
| 271 |
+
voxel_indice = compute_discrete_voxel_indice(position_maps, grid_resolution, voxel_resolution)
|
| 272 |
+
voxel_indice = rearrange(voxel_indice, "b n c h w -> b (n h w) c")
|
| 273 |
+
voxel_indices[voxel_indice.shape[1]] = {"voxel_indices": voxel_indice, "voxel_resolution": voxel_resolution}
|
| 274 |
+
return voxel_indices
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class Basic2p5DTransformerBlock(torch.nn.Module):
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
"""Enhanced transformer block for multiview 2.5D image generation.
|
| 281 |
+
|
| 282 |
+
Extends standard transformer blocks with:
|
| 283 |
+
- Material-specific attention (MDA)
|
| 284 |
+
- Multiview attention (MA)
|
| 285 |
+
- Reference attention (RA)
|
| 286 |
+
- DINO feature integration
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
transformer: Base transformer block
|
| 290 |
+
layer_name: Identifier for layer
|
| 291 |
+
use_ma: Enable multiview attention
|
| 292 |
+
use_ra: Enable reference attention
|
| 293 |
+
use_mda: Enable material-aware attention
|
| 294 |
+
use_dino: Enable DINO feature integration
|
| 295 |
+
pbr_setting: List of PBR materials
|
| 296 |
+
"""
|
| 297 |
+
|
| 298 |
+
def __init__(
|
| 299 |
+
self,
|
| 300 |
+
transformer: BasicTransformerBlock,
|
| 301 |
+
layer_name,
|
| 302 |
+
use_ma=True,
|
| 303 |
+
use_ra=True,
|
| 304 |
+
use_mda=True,
|
| 305 |
+
use_dino=True,
|
| 306 |
+
pbr_setting=None,
|
| 307 |
+
) -> None:
|
| 308 |
+
|
| 309 |
+
"""
|
| 310 |
+
Initialization:
|
| 311 |
+
1. Material-Dimension Attention (MDA):
|
| 312 |
+
- Processes each PBR material with separate projection weights
|
| 313 |
+
- Uses custom SelfAttnProcessor2_0 with material awareness
|
| 314 |
+
|
| 315 |
+
2. Multiview Attention (MA):
|
| 316 |
+
- Adds cross-view attention with PoseRoPE
|
| 317 |
+
- Initialized as zero-initialized residual pathway
|
| 318 |
+
|
| 319 |
+
3. Reference Attention (RA):
|
| 320 |
+
- Conditions on reference view features
|
| 321 |
+
- Uses RefAttnProcessor2_0 for material-specific conditioning
|
| 322 |
+
|
| 323 |
+
4. DINO Attention:
|
| 324 |
+
- Incorporates DINO-ViT features
|
| 325 |
+
- Initialized as zero-initialized residual pathway
|
| 326 |
+
"""
|
| 327 |
+
|
| 328 |
+
super().__init__()
|
| 329 |
+
self.transformer = transformer
|
| 330 |
+
self.layer_name = layer_name
|
| 331 |
+
self.use_ma = use_ma
|
| 332 |
+
self.use_ra = use_ra
|
| 333 |
+
self.use_mda = use_mda
|
| 334 |
+
self.use_dino = use_dino
|
| 335 |
+
self.pbr_setting = pbr_setting
|
| 336 |
+
|
| 337 |
+
if self.use_mda:
|
| 338 |
+
self.attn1.set_processor(
|
| 339 |
+
SelfAttnProcessor2_0(
|
| 340 |
+
query_dim=self.dim,
|
| 341 |
+
heads=self.num_attention_heads,
|
| 342 |
+
dim_head=self.attention_head_dim,
|
| 343 |
+
dropout=self.dropout,
|
| 344 |
+
bias=self.attention_bias,
|
| 345 |
+
cross_attention_dim=None,
|
| 346 |
+
upcast_attention=self.attn1.upcast_attention,
|
| 347 |
+
out_bias=True,
|
| 348 |
+
pbr_setting=self.pbr_setting,
|
| 349 |
+
)
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
# multiview attn
|
| 353 |
+
if self.use_ma:
|
| 354 |
+
self.attn_multiview = Attention(
|
| 355 |
+
query_dim=self.dim,
|
| 356 |
+
heads=self.num_attention_heads,
|
| 357 |
+
dim_head=self.attention_head_dim,
|
| 358 |
+
dropout=self.dropout,
|
| 359 |
+
bias=self.attention_bias,
|
| 360 |
+
cross_attention_dim=None,
|
| 361 |
+
upcast_attention=self.attn1.upcast_attention,
|
| 362 |
+
out_bias=True,
|
| 363 |
+
processor=PoseRoPEAttnProcessor2_0(),
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
# ref attn
|
| 367 |
+
if self.use_ra:
|
| 368 |
+
self.attn_refview = Attention(
|
| 369 |
+
query_dim=self.dim,
|
| 370 |
+
heads=self.num_attention_heads,
|
| 371 |
+
dim_head=self.attention_head_dim,
|
| 372 |
+
dropout=self.dropout,
|
| 373 |
+
bias=self.attention_bias,
|
| 374 |
+
cross_attention_dim=None,
|
| 375 |
+
upcast_attention=self.attn1.upcast_attention,
|
| 376 |
+
out_bias=True,
|
| 377 |
+
processor=RefAttnProcessor2_0(
|
| 378 |
+
query_dim=self.dim,
|
| 379 |
+
heads=self.num_attention_heads,
|
| 380 |
+
dim_head=self.attention_head_dim,
|
| 381 |
+
dropout=self.dropout,
|
| 382 |
+
bias=self.attention_bias,
|
| 383 |
+
cross_attention_dim=None,
|
| 384 |
+
upcast_attention=self.attn1.upcast_attention,
|
| 385 |
+
out_bias=True,
|
| 386 |
+
pbr_setting=self.pbr_setting,
|
| 387 |
+
),
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
# dino attn
|
| 391 |
+
if self.use_dino:
|
| 392 |
+
self.attn_dino = Attention(
|
| 393 |
+
query_dim=self.dim,
|
| 394 |
+
heads=self.num_attention_heads,
|
| 395 |
+
dim_head=self.attention_head_dim,
|
| 396 |
+
dropout=self.dropout,
|
| 397 |
+
bias=self.attention_bias,
|
| 398 |
+
cross_attention_dim=self.cross_attention_dim,
|
| 399 |
+
upcast_attention=self.attn2.upcast_attention,
|
| 400 |
+
out_bias=True,
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
self._initialize_attn_weights()
|
| 404 |
+
|
| 405 |
+
def _initialize_attn_weights(self):
|
| 406 |
+
|
| 407 |
+
"""Initializes specialized attention heads with base weights.
|
| 408 |
+
|
| 409 |
+
Uses weight sharing strategy:
|
| 410 |
+
- Copies base transformer weights to specialized heads
|
| 411 |
+
- Initializes newly-added parameters to zero
|
| 412 |
+
"""
|
| 413 |
+
|
| 414 |
+
if self.use_mda:
|
| 415 |
+
for token in self.pbr_setting:
|
| 416 |
+
if token == "albedo":
|
| 417 |
+
continue
|
| 418 |
+
getattr(self.attn1.processor, f"to_q_{token}").load_state_dict(self.attn1.to_q.state_dict())
|
| 419 |
+
getattr(self.attn1.processor, f"to_k_{token}").load_state_dict(self.attn1.to_k.state_dict())
|
| 420 |
+
getattr(self.attn1.processor, f"to_v_{token}").load_state_dict(self.attn1.to_v.state_dict())
|
| 421 |
+
getattr(self.attn1.processor, f"to_out_{token}").load_state_dict(self.attn1.to_out.state_dict())
|
| 422 |
+
|
| 423 |
+
if self.use_ma:
|
| 424 |
+
self.attn_multiview.load_state_dict(self.attn1.state_dict(), strict=False)
|
| 425 |
+
with torch.no_grad():
|
| 426 |
+
for layer in self.attn_multiview.to_out:
|
| 427 |
+
for param in layer.parameters():
|
| 428 |
+
param.zero_()
|
| 429 |
+
|
| 430 |
+
if self.use_ra:
|
| 431 |
+
self.attn_refview.load_state_dict(self.attn1.state_dict(), strict=False)
|
| 432 |
+
for token in self.pbr_setting:
|
| 433 |
+
if token == "albedo":
|
| 434 |
+
continue
|
| 435 |
+
getattr(self.attn_refview.processor, f"to_v_{token}").load_state_dict(
|
| 436 |
+
self.attn_refview.to_q.state_dict()
|
| 437 |
+
)
|
| 438 |
+
getattr(self.attn_refview.processor, f"to_out_{token}").load_state_dict(
|
| 439 |
+
self.attn_refview.to_out.state_dict()
|
| 440 |
+
)
|
| 441 |
+
with torch.no_grad():
|
| 442 |
+
for layer in self.attn_refview.to_out:
|
| 443 |
+
for param in layer.parameters():
|
| 444 |
+
param.zero_()
|
| 445 |
+
for token in self.pbr_setting:
|
| 446 |
+
if token == "albedo":
|
| 447 |
+
continue
|
| 448 |
+
for layer in getattr(self.attn_refview.processor, f"to_out_{token}"):
|
| 449 |
+
for param in layer.parameters():
|
| 450 |
+
param.zero_()
|
| 451 |
+
|
| 452 |
+
if self.use_dino:
|
| 453 |
+
self.attn_dino.load_state_dict(self.attn2.state_dict(), strict=False)
|
| 454 |
+
with torch.no_grad():
|
| 455 |
+
for layer in self.attn_dino.to_out:
|
| 456 |
+
for param in layer.parameters():
|
| 457 |
+
param.zero_()
|
| 458 |
+
|
| 459 |
+
if self.use_dino:
|
| 460 |
+
self.attn_dino.load_state_dict(self.attn2.state_dict(), strict=False)
|
| 461 |
+
with torch.no_grad():
|
| 462 |
+
for layer in self.attn_dino.to_out:
|
| 463 |
+
for param in layer.parameters():
|
| 464 |
+
param.zero_()
|
| 465 |
+
|
| 466 |
+
def __getattr__(self, name: str):
|
| 467 |
+
try:
|
| 468 |
+
return super().__getattr__(name)
|
| 469 |
+
except AttributeError:
|
| 470 |
+
return getattr(self.transformer, name)
|
| 471 |
+
|
| 472 |
+
def forward(
|
| 473 |
+
self,
|
| 474 |
+
hidden_states: torch.Tensor,
|
| 475 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 476 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 477 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 478 |
+
timestep: Optional[torch.LongTensor] = None,
|
| 479 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
| 480 |
+
class_labels: Optional[torch.LongTensor] = None,
|
| 481 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
| 482 |
+
) -> torch.Tensor:
|
| 483 |
+
|
| 484 |
+
"""Forward pass with multi-mechanism attention.
|
| 485 |
+
|
| 486 |
+
Processing stages:
|
| 487 |
+
1. Material-aware self-attention (MDA)
|
| 488 |
+
2. Reference attention (RA)
|
| 489 |
+
3. Multiview attention (MA) with position-aware attention
|
| 490 |
+
4. Text conditioning (base attention)
|
| 491 |
+
5. DINO feature conditioning (optional)
|
| 492 |
+
6. Position-aware conditioning
|
| 493 |
+
7. Feed-forward network
|
| 494 |
+
|
| 495 |
+
Args:
|
| 496 |
+
hidden_states: Input features [B * N_materials * N_views, Seq_len, Feat_dim]
|
| 497 |
+
See base transformer for other parameters
|
| 498 |
+
|
| 499 |
+
Returns:
|
| 500 |
+
torch.Tensor: Output features
|
| 501 |
+
"""
|
| 502 |
+
# [Full multi-mechanism processing pipeline...]
|
| 503 |
+
# Key processing stages:
|
| 504 |
+
# 1. Material-aware self-attention (handles albedo/mr separation)
|
| 505 |
+
# 2. Reference attention (conditioned on reference features)
|
| 506 |
+
# 3. View-to-view attention with geometric constraints
|
| 507 |
+
# 4. Text-to-image cross-attention
|
| 508 |
+
# 5. DINO feature fusion (when enabled)
|
| 509 |
+
# 6. Positional conditioning (RoPE-style)
|
| 510 |
+
# 7. Feed-forward network with conditional normalization
|
| 511 |
+
|
| 512 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
| 513 |
+
# 0. Self-Attention
|
| 514 |
+
batch_size = hidden_states.shape[0]
|
| 515 |
+
|
| 516 |
+
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
| 517 |
+
num_in_batch = cross_attention_kwargs.pop("num_in_batch", 1)
|
| 518 |
+
mode = cross_attention_kwargs.pop("mode", None)
|
| 519 |
+
mva_scale = cross_attention_kwargs.pop("mva_scale", 1.0)
|
| 520 |
+
ref_scale = cross_attention_kwargs.pop("ref_scale", 1.0)
|
| 521 |
+
condition_embed_dict = cross_attention_kwargs.pop("condition_embed_dict", None)
|
| 522 |
+
dino_hidden_states = cross_attention_kwargs.pop("dino_hidden_states", None)
|
| 523 |
+
position_voxel_indices = cross_attention_kwargs.pop("position_voxel_indices", None)
|
| 524 |
+
N_pbr = len(self.pbr_setting) if self.pbr_setting is not None else 1
|
| 525 |
+
|
| 526 |
+
if self.norm_type == "ada_norm":
|
| 527 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
| 528 |
+
elif self.norm_type == "ada_norm_zero":
|
| 529 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
| 530 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
| 531 |
+
)
|
| 532 |
+
elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
|
| 533 |
+
norm_hidden_states = self.norm1(hidden_states)
|
| 534 |
+
elif self.norm_type == "ada_norm_continuous":
|
| 535 |
+
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
| 536 |
+
elif self.norm_type == "ada_norm_single":
|
| 537 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
| 538 |
+
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
| 539 |
+
).chunk(6, dim=1)
|
| 540 |
+
norm_hidden_states = self.norm1(hidden_states)
|
| 541 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
| 542 |
+
else:
|
| 543 |
+
raise ValueError("Incorrect norm used")
|
| 544 |
+
|
| 545 |
+
if self.pos_embed is not None:
|
| 546 |
+
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
| 547 |
+
|
| 548 |
+
# 1. Prepare GLIGEN inputs
|
| 549 |
+
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
| 550 |
+
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
|
| 551 |
+
|
| 552 |
+
if self.use_mda:
|
| 553 |
+
mda_norm_hidden_states = rearrange(
|
| 554 |
+
norm_hidden_states, "(b n_pbr n) l c -> b n_pbr n l c", n=num_in_batch, n_pbr=N_pbr
|
| 555 |
+
)
|
| 556 |
+
attn_output = self.attn1(
|
| 557 |
+
mda_norm_hidden_states,
|
| 558 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
| 559 |
+
attention_mask=attention_mask,
|
| 560 |
+
**cross_attention_kwargs,
|
| 561 |
+
)
|
| 562 |
+
attn_output = rearrange(attn_output, "b n_pbr n l c -> (b n_pbr n) l c")
|
| 563 |
+
else:
|
| 564 |
+
attn_output = self.attn1(
|
| 565 |
+
norm_hidden_states,
|
| 566 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
| 567 |
+
attention_mask=attention_mask,
|
| 568 |
+
**cross_attention_kwargs,
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
if self.norm_type == "ada_norm_zero":
|
| 572 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
| 573 |
+
elif self.norm_type == "ada_norm_single":
|
| 574 |
+
attn_output = gate_msa * attn_output
|
| 575 |
+
|
| 576 |
+
hidden_states = attn_output + hidden_states
|
| 577 |
+
if hidden_states.ndim == 4:
|
| 578 |
+
hidden_states = hidden_states.squeeze(1)
|
| 579 |
+
|
| 580 |
+
# 1.2 Reference Attention
|
| 581 |
+
if "w" in mode:
|
| 582 |
+
condition_embed_dict[self.layer_name] = rearrange(
|
| 583 |
+
norm_hidden_states, "(b n) l c -> b (n l) c", n=num_in_batch
|
| 584 |
+
) # B, (N L), C
|
| 585 |
+
|
| 586 |
+
if "r" in mode and self.use_ra:
|
| 587 |
+
condition_embed = condition_embed_dict[self.layer_name]
|
| 588 |
+
|
| 589 |
+
#! Only using albedo features for reference attention
|
| 590 |
+
ref_norm_hidden_states = rearrange(
|
| 591 |
+
norm_hidden_states, "(b n_pbr n) l c -> b n_pbr (n l) c", n=num_in_batch, n_pbr=N_pbr
|
| 592 |
+
)[:, 0, ...]
|
| 593 |
+
|
| 594 |
+
attn_output = self.attn_refview(
|
| 595 |
+
ref_norm_hidden_states,
|
| 596 |
+
encoder_hidden_states=condition_embed,
|
| 597 |
+
attention_mask=None,
|
| 598 |
+
**cross_attention_kwargs,
|
| 599 |
+
) # b (n l) c
|
| 600 |
+
attn_output = rearrange(attn_output, "b n_pbr (n l) c -> (b n_pbr n) l c", n=num_in_batch, n_pbr=N_pbr)
|
| 601 |
+
|
| 602 |
+
ref_scale_timing = ref_scale
|
| 603 |
+
if isinstance(ref_scale, torch.Tensor):
|
| 604 |
+
ref_scale_timing = ref_scale.unsqueeze(1).repeat(1, num_in_batch * N_pbr).view(-1)
|
| 605 |
+
for _ in range(attn_output.ndim - 1):
|
| 606 |
+
ref_scale_timing = ref_scale_timing.unsqueeze(-1)
|
| 607 |
+
hidden_states = ref_scale_timing * attn_output + hidden_states
|
| 608 |
+
if hidden_states.ndim == 4:
|
| 609 |
+
hidden_states = hidden_states.squeeze(1)
|
| 610 |
+
|
| 611 |
+
# 1.3 Multiview Attention
|
| 612 |
+
if num_in_batch > 1 and self.use_ma:
|
| 613 |
+
multivew_hidden_states = rearrange(
|
| 614 |
+
norm_hidden_states, "(b n_pbr n) l c -> (b n_pbr) (n l) c", n_pbr=N_pbr, n=num_in_batch
|
| 615 |
+
)
|
| 616 |
+
position_indices = None
|
| 617 |
+
if position_voxel_indices is not None:
|
| 618 |
+
if multivew_hidden_states.shape[1] in position_voxel_indices:
|
| 619 |
+
position_indices = position_voxel_indices[multivew_hidden_states.shape[1]]
|
| 620 |
+
|
| 621 |
+
attn_output = self.attn_multiview(
|
| 622 |
+
multivew_hidden_states,
|
| 623 |
+
encoder_hidden_states=multivew_hidden_states,
|
| 624 |
+
position_indices=position_indices,
|
| 625 |
+
n_pbrs=N_pbr,
|
| 626 |
+
**cross_attention_kwargs,
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
attn_output = rearrange(attn_output, "(b n_pbr) (n l) c -> (b n_pbr n) l c", n_pbr=N_pbr, n=num_in_batch)
|
| 630 |
+
|
| 631 |
+
hidden_states = mva_scale * attn_output + hidden_states
|
| 632 |
+
if hidden_states.ndim == 4:
|
| 633 |
+
hidden_states = hidden_states.squeeze(1)
|
| 634 |
+
|
| 635 |
+
# 1.2 GLIGEN Control
|
| 636 |
+
if gligen_kwargs is not None:
|
| 637 |
+
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
| 638 |
+
|
| 639 |
+
# 3. Cross-Attention
|
| 640 |
+
if self.attn2 is not None:
|
| 641 |
+
if self.norm_type == "ada_norm":
|
| 642 |
+
norm_hidden_states = self.norm2(hidden_states, timestep)
|
| 643 |
+
elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
|
| 644 |
+
norm_hidden_states = self.norm2(hidden_states)
|
| 645 |
+
elif self.norm_type == "ada_norm_single":
|
| 646 |
+
# For PixArt norm2 isn't applied here:
|
| 647 |
+
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
|
| 648 |
+
norm_hidden_states = hidden_states
|
| 649 |
+
elif self.norm_type == "ada_norm_continuous":
|
| 650 |
+
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
| 651 |
+
else:
|
| 652 |
+
raise ValueError("Incorrect norm")
|
| 653 |
+
|
| 654 |
+
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
|
| 655 |
+
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
| 656 |
+
|
| 657 |
+
attn_output = self.attn2(
|
| 658 |
+
norm_hidden_states,
|
| 659 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 660 |
+
attention_mask=encoder_attention_mask,
|
| 661 |
+
**cross_attention_kwargs,
|
| 662 |
+
)
|
| 663 |
+
hidden_states = attn_output + hidden_states
|
| 664 |
+
|
| 665 |
+
# dino attn
|
| 666 |
+
if self.use_dino:
|
| 667 |
+
dino_hidden_states = dino_hidden_states.unsqueeze(1).repeat(1, N_pbr * num_in_batch, 1, 1)
|
| 668 |
+
dino_hidden_states = rearrange(dino_hidden_states, "b n l c -> (b n) l c")
|
| 669 |
+
attn_output = self.attn_dino(
|
| 670 |
+
norm_hidden_states,
|
| 671 |
+
encoder_hidden_states=dino_hidden_states,
|
| 672 |
+
attention_mask=None,
|
| 673 |
+
**cross_attention_kwargs,
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
hidden_states = attn_output + hidden_states
|
| 677 |
+
|
| 678 |
+
# 4. Feed-forward
|
| 679 |
+
# i2vgen doesn't have this norm 🤷♂️
|
| 680 |
+
if self.norm_type == "ada_norm_continuous":
|
| 681 |
+
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
| 682 |
+
elif not self.norm_type == "ada_norm_single":
|
| 683 |
+
norm_hidden_states = self.norm3(hidden_states)
|
| 684 |
+
|
| 685 |
+
if self.norm_type == "ada_norm_zero":
|
| 686 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 687 |
+
|
| 688 |
+
if self.norm_type == "ada_norm_single":
|
| 689 |
+
norm_hidden_states = self.norm2(hidden_states)
|
| 690 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
| 691 |
+
|
| 692 |
+
if self._chunk_size is not None:
|
| 693 |
+
# "feed_forward_chunk_size" can be used to save memory
|
| 694 |
+
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
| 695 |
+
else:
|
| 696 |
+
ff_output = self.ff(norm_hidden_states)
|
| 697 |
+
|
| 698 |
+
if self.norm_type == "ada_norm_zero":
|
| 699 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
| 700 |
+
elif self.norm_type == "ada_norm_single":
|
| 701 |
+
ff_output = gate_mlp * ff_output
|
| 702 |
+
|
| 703 |
+
hidden_states = ff_output + hidden_states
|
| 704 |
+
if hidden_states.ndim == 4:
|
| 705 |
+
hidden_states = hidden_states.squeeze(1)
|
| 706 |
+
|
| 707 |
+
return hidden_states
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
class ImageProjModel(torch.nn.Module):
|
| 711 |
+
|
| 712 |
+
"""Projects image embeddings into cross-attention space.
|
| 713 |
+
|
| 714 |
+
Transforms CLIP embeddings into additional context tokens for conditioning.
|
| 715 |
+
|
| 716 |
+
Args:
|
| 717 |
+
cross_attention_dim: Dimension of attention space
|
| 718 |
+
clip_embeddings_dim: Dimension of input CLIP embeddings
|
| 719 |
+
clip_extra_context_tokens: Number of context tokens to generate
|
| 720 |
+
"""
|
| 721 |
+
|
| 722 |
+
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
|
| 723 |
+
super().__init__()
|
| 724 |
+
|
| 725 |
+
self.generator = None
|
| 726 |
+
self.cross_attention_dim = cross_attention_dim
|
| 727 |
+
self.clip_extra_context_tokens = clip_extra_context_tokens
|
| 728 |
+
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
|
| 729 |
+
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
| 730 |
+
|
| 731 |
+
def forward(self, image_embeds):
|
| 732 |
+
|
| 733 |
+
"""Projects image embeddings to cross-attention context tokens.
|
| 734 |
+
|
| 735 |
+
Args:
|
| 736 |
+
image_embeds: Input embeddings [B, N, C] or [B, C]
|
| 737 |
+
|
| 738 |
+
Returns:
|
| 739 |
+
torch.Tensor: Context tokens [B, N*clip_extra_context_tokens, cross_attention_dim]
|
| 740 |
+
"""
|
| 741 |
+
|
| 742 |
+
embeds = image_embeds
|
| 743 |
+
num_token = 1
|
| 744 |
+
if embeds.dim() == 3:
|
| 745 |
+
num_token = embeds.shape[1]
|
| 746 |
+
embeds = rearrange(embeds, "b n c -> (b n) c")
|
| 747 |
+
|
| 748 |
+
clip_extra_context_tokens = self.proj(embeds).reshape(
|
| 749 |
+
-1, self.clip_extra_context_tokens, self.cross_attention_dim
|
| 750 |
+
)
|
| 751 |
+
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
|
| 752 |
+
|
| 753 |
+
clip_extra_context_tokens = rearrange(clip_extra_context_tokens, "(b nt) n c -> b (nt n) c", nt=num_token)
|
| 754 |
+
|
| 755 |
+
return clip_extra_context_tokens
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
class UNet2p5DConditionModel(torch.nn.Module):
|
| 759 |
+
|
| 760 |
+
"""2.5D UNet extension for multiview PBR generation.
|
| 761 |
+
|
| 762 |
+
Enhances standard 2D UNet with:
|
| 763 |
+
- Multiview attention mechanisms
|
| 764 |
+
- Material-aware processing
|
| 765 |
+
- Position-aware conditioning
|
| 766 |
+
- Dual-stream reference processing
|
| 767 |
+
|
| 768 |
+
Args:
|
| 769 |
+
unet: Base 2D UNet model
|
| 770 |
+
train_sched: Training scheduler (DDPM)
|
| 771 |
+
val_sched: Validation scheduler (EulerAncestral)
|
| 772 |
+
"""
|
| 773 |
+
|
| 774 |
+
def __init__(
|
| 775 |
+
self,
|
| 776 |
+
unet: UNet2DConditionModel,
|
| 777 |
+
train_sched: DDPMScheduler = None,
|
| 778 |
+
val_sched: EulerAncestralDiscreteScheduler = None,
|
| 779 |
+
) -> None:
|
| 780 |
+
super().__init__()
|
| 781 |
+
self.unet = unet
|
| 782 |
+
self.train_sched = train_sched
|
| 783 |
+
self.val_sched = val_sched
|
| 784 |
+
|
| 785 |
+
self.use_ma = True
|
| 786 |
+
self.use_ra = True
|
| 787 |
+
self.use_mda = True
|
| 788 |
+
self.use_dino = True
|
| 789 |
+
self.use_position_rope = True
|
| 790 |
+
self.use_learned_text_clip = True
|
| 791 |
+
self.use_dual_stream = True
|
| 792 |
+
self.pbr_setting = ["albedo", "mr"]
|
| 793 |
+
self.pbr_token_channels = 77
|
| 794 |
+
|
| 795 |
+
if self.use_dual_stream and self.use_ra:
|
| 796 |
+
self.unet_dual = copy.deepcopy(unet)
|
| 797 |
+
self.init_attention(self.unet_dual)
|
| 798 |
+
|
| 799 |
+
self.init_attention(
|
| 800 |
+
self.unet,
|
| 801 |
+
use_ma=self.use_ma,
|
| 802 |
+
use_ra=self.use_ra,
|
| 803 |
+
use_dino=self.use_dino,
|
| 804 |
+
use_mda=self.use_mda,
|
| 805 |
+
pbr_setting=self.pbr_setting,
|
| 806 |
+
)
|
| 807 |
+
self.init_condition(use_dino=self.use_dino)
|
| 808 |
+
|
| 809 |
+
@staticmethod
|
| 810 |
+
def from_pretrained(pretrained_model_name_or_path, **kwargs):
|
| 811 |
+
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
| 812 |
+
config_path = os.path.join(pretrained_model_name_or_path, "config.json")
|
| 813 |
+
unet_ckpt_path = os.path.join(pretrained_model_name_or_path, "diffusion_pytorch_model.bin")
|
| 814 |
+
with open(config_path, "r", encoding="utf-8") as file:
|
| 815 |
+
config = json.load(file)
|
| 816 |
+
unet = UNet2DConditionModel(**config)
|
| 817 |
+
unet_2p5d = UNet2p5DConditionModel(unet)
|
| 818 |
+
unet_2p5d.unet.conv_in = torch.nn.Conv2d(
|
| 819 |
+
12,
|
| 820 |
+
unet.conv_in.out_channels,
|
| 821 |
+
kernel_size=unet.conv_in.kernel_size,
|
| 822 |
+
stride=unet.conv_in.stride,
|
| 823 |
+
padding=unet.conv_in.padding,
|
| 824 |
+
dilation=unet.conv_in.dilation,
|
| 825 |
+
groups=unet.conv_in.groups,
|
| 826 |
+
bias=unet.conv_in.bias is not None,
|
| 827 |
+
)
|
| 828 |
+
unet_ckpt = torch.load(unet_ckpt_path, map_location="cpu", weights_only=True)
|
| 829 |
+
unet_2p5d.load_state_dict(unet_ckpt, strict=True)
|
| 830 |
+
unet_2p5d = unet_2p5d.to(torch_dtype)
|
| 831 |
+
return unet_2p5d
|
| 832 |
+
|
| 833 |
+
def init_condition(self, use_dino):
|
| 834 |
+
|
| 835 |
+
"""Initializes conditioning mechanisms for multiview PBR generation.
|
| 836 |
+
|
| 837 |
+
Sets up:
|
| 838 |
+
1. Learned text embeddings: Material-specific tokens (albedo, mr) initialized to zeros
|
| 839 |
+
2. DINO projector: Model to process DINO-ViT features for cross-attention
|
| 840 |
+
|
| 841 |
+
Args:
|
| 842 |
+
use_dino: Flag to enable DINO feature integration
|
| 843 |
+
"""
|
| 844 |
+
|
| 845 |
+
if self.use_learned_text_clip:
|
| 846 |
+
for token in self.pbr_setting:
|
| 847 |
+
self.unet.register_parameter(
|
| 848 |
+
f"learned_text_clip_{token}", nn.Parameter(torch.zeros(self.pbr_token_channels, 1024))
|
| 849 |
+
)
|
| 850 |
+
self.unet.learned_text_clip_ref = nn.Parameter(torch.zeros(self.pbr_token_channels, 1024))
|
| 851 |
+
|
| 852 |
+
if use_dino:
|
| 853 |
+
self.unet.image_proj_model_dino = ImageProjModel(
|
| 854 |
+
cross_attention_dim=self.unet.config.cross_attention_dim,
|
| 855 |
+
clip_embeddings_dim=1536,
|
| 856 |
+
clip_extra_context_tokens=4,
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
def init_attention(self, unet, use_ma=False, use_ra=False, use_mda=False, use_dino=False, pbr_setting=None):
|
| 860 |
+
|
| 861 |
+
"""Recursively replaces standard transformers with enhanced 2.5D blocks.
|
| 862 |
+
|
| 863 |
+
Processes UNet architecture:
|
| 864 |
+
1. Downsampling blocks: Replaces transformers in attention layers
|
| 865 |
+
2. Middle block: Upgrades central transformers
|
| 866 |
+
3. Upsampling blocks: Modifies decoder transformers
|
| 867 |
+
|
| 868 |
+
Args:
|
| 869 |
+
unet: UNet model to enhance
|
| 870 |
+
use_ma: Enable multiview attention
|
| 871 |
+
use_ra: Enable reference attention
|
| 872 |
+
use_mda: Enable material-specific attention
|
| 873 |
+
use_dino: Enable DINO feature integration
|
| 874 |
+
pbr_setting: List of PBR materials
|
| 875 |
+
"""
|
| 876 |
+
|
| 877 |
+
for down_block_i, down_block in enumerate(unet.down_blocks):
|
| 878 |
+
if hasattr(down_block, "has_cross_attention") and down_block.has_cross_attention:
|
| 879 |
+
for attn_i, attn in enumerate(down_block.attentions):
|
| 880 |
+
for transformer_i, transformer in enumerate(attn.transformer_blocks):
|
| 881 |
+
if isinstance(transformer, BasicTransformerBlock):
|
| 882 |
+
attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(
|
| 883 |
+
transformer,
|
| 884 |
+
f"down_{down_block_i}_{attn_i}_{transformer_i}",
|
| 885 |
+
use_ma,
|
| 886 |
+
use_ra,
|
| 887 |
+
use_mda,
|
| 888 |
+
use_dino,
|
| 889 |
+
pbr_setting,
|
| 890 |
+
)
|
| 891 |
+
|
| 892 |
+
if hasattr(unet.mid_block, "has_cross_attention") and unet.mid_block.has_cross_attention:
|
| 893 |
+
for attn_i, attn in enumerate(unet.mid_block.attentions):
|
| 894 |
+
for transformer_i, transformer in enumerate(attn.transformer_blocks):
|
| 895 |
+
if isinstance(transformer, BasicTransformerBlock):
|
| 896 |
+
attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(
|
| 897 |
+
transformer, f"mid_{attn_i}_{transformer_i}", use_ma, use_ra, use_mda, use_dino, pbr_setting
|
| 898 |
+
)
|
| 899 |
+
|
| 900 |
+
for up_block_i, up_block in enumerate(unet.up_blocks):
|
| 901 |
+
if hasattr(up_block, "has_cross_attention") and up_block.has_cross_attention:
|
| 902 |
+
for attn_i, attn in enumerate(up_block.attentions):
|
| 903 |
+
for transformer_i, transformer in enumerate(attn.transformer_blocks):
|
| 904 |
+
if isinstance(transformer, BasicTransformerBlock):
|
| 905 |
+
attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(
|
| 906 |
+
transformer,
|
| 907 |
+
f"up_{up_block_i}_{attn_i}_{transformer_i}",
|
| 908 |
+
use_ma,
|
| 909 |
+
use_ra,
|
| 910 |
+
use_mda,
|
| 911 |
+
use_dino,
|
| 912 |
+
pbr_setting,
|
| 913 |
+
)
|
| 914 |
+
|
| 915 |
+
def __getattr__(self, name: str):
|
| 916 |
+
try:
|
| 917 |
+
return super().__getattr__(name)
|
| 918 |
+
except AttributeError:
|
| 919 |
+
return getattr(self.unet, name)
|
| 920 |
+
|
| 921 |
+
def forward(
|
| 922 |
+
self,
|
| 923 |
+
sample,
|
| 924 |
+
timestep,
|
| 925 |
+
encoder_hidden_states,
|
| 926 |
+
*args,
|
| 927 |
+
added_cond_kwargs=None,
|
| 928 |
+
cross_attention_kwargs=None,
|
| 929 |
+
down_intrablock_additional_residuals=None,
|
| 930 |
+
down_block_res_samples=None,
|
| 931 |
+
mid_block_res_sample=None,
|
| 932 |
+
**cached_condition,
|
| 933 |
+
):
|
| 934 |
+
|
| 935 |
+
"""Forward pass with multiview/material conditioning.
|
| 936 |
+
|
| 937 |
+
Key stages:
|
| 938 |
+
1. Input preparation (concat normal/position maps)
|
| 939 |
+
2. Reference feature extraction (dual-stream)
|
| 940 |
+
3. Position encoding (voxel indices)
|
| 941 |
+
4. DINO feature projection
|
| 942 |
+
5. Main UNet processing with attention conditioning
|
| 943 |
+
|
| 944 |
+
Args:
|
| 945 |
+
sample: Input latents [B, N_pbr, N_gen, C, H, W]
|
| 946 |
+
cached_condition: Dictionary containing:
|
| 947 |
+
- embeds_normal: Normal map embeddings
|
| 948 |
+
- embeds_position: Position map embeddings
|
| 949 |
+
- ref_latents: Reference image latents
|
| 950 |
+
- dino_hidden_states: DINO features
|
| 951 |
+
- position_maps: 3D position maps
|
| 952 |
+
- mva_scale: Multiview attention scale
|
| 953 |
+
- ref_scale: Reference attention scale
|
| 954 |
+
|
| 955 |
+
Returns:
|
| 956 |
+
torch.Tensor: Output features
|
| 957 |
+
"""
|
| 958 |
+
|
| 959 |
+
B, N_pbr, N_gen, _, H, W = sample.shape
|
| 960 |
+
assert H == W
|
| 961 |
+
|
| 962 |
+
if "cache" not in cached_condition:
|
| 963 |
+
cached_condition["cache"] = {}
|
| 964 |
+
|
| 965 |
+
sample = [sample]
|
| 966 |
+
if "embeds_normal" in cached_condition:
|
| 967 |
+
sample.append(cached_condition["embeds_normal"].unsqueeze(1).repeat(1, N_pbr, 1, 1, 1, 1))
|
| 968 |
+
if "embeds_position" in cached_condition:
|
| 969 |
+
sample.append(cached_condition["embeds_position"].unsqueeze(1).repeat(1, N_pbr, 1, 1, 1, 1))
|
| 970 |
+
sample = torch.cat(sample, dim=-3)
|
| 971 |
+
|
| 972 |
+
sample = rearrange(sample, "b n_pbr n c h w -> (b n_pbr n) c h w")
|
| 973 |
+
|
| 974 |
+
encoder_hidden_states_gen = encoder_hidden_states.unsqueeze(-3).repeat(1, 1, N_gen, 1, 1)
|
| 975 |
+
encoder_hidden_states_gen = rearrange(encoder_hidden_states_gen, "b n_pbr n l c -> (b n_pbr n) l c")
|
| 976 |
+
|
| 977 |
+
if added_cond_kwargs is not None:
|
| 978 |
+
text_embeds_gen = added_cond_kwargs["text_embeds"].unsqueeze(1).repeat(1, N_gen, 1)
|
| 979 |
+
text_embeds_gen = rearrange(text_embeds_gen, "b n c -> (b n) c")
|
| 980 |
+
time_ids_gen = added_cond_kwargs["time_ids"].unsqueeze(1).repeat(1, N_gen, 1)
|
| 981 |
+
time_ids_gen = rearrange(time_ids_gen, "b n c -> (b n) c")
|
| 982 |
+
added_cond_kwargs_gen = {"text_embeds": text_embeds_gen, "time_ids": time_ids_gen}
|
| 983 |
+
else:
|
| 984 |
+
added_cond_kwargs_gen = None
|
| 985 |
+
|
| 986 |
+
if self.use_position_rope:
|
| 987 |
+
if "position_voxel_indices" in cached_condition["cache"]:
|
| 988 |
+
position_voxel_indices = cached_condition["cache"]["position_voxel_indices"]
|
| 989 |
+
else:
|
| 990 |
+
if "position_maps" in cached_condition:
|
| 991 |
+
position_voxel_indices = calc_multires_voxel_idxs(
|
| 992 |
+
cached_condition["position_maps"],
|
| 993 |
+
grid_resolutions=[H, H // 2, H // 4, H // 8],
|
| 994 |
+
voxel_resolutions=[H * 8, H * 4, H * 2, H],
|
| 995 |
+
)
|
| 996 |
+
cached_condition["cache"]["position_voxel_indices"] = position_voxel_indices
|
| 997 |
+
else:
|
| 998 |
+
position_voxel_indices = None
|
| 999 |
+
|
| 1000 |
+
if self.use_dino:
|
| 1001 |
+
if "dino_hidden_states_proj" in cached_condition["cache"]:
|
| 1002 |
+
dino_hidden_states = cached_condition["cache"]["dino_hidden_states_proj"]
|
| 1003 |
+
else:
|
| 1004 |
+
assert "dino_hidden_states" in cached_condition
|
| 1005 |
+
dino_hidden_states = cached_condition["dino_hidden_states"]
|
| 1006 |
+
dino_hidden_states = self.image_proj_model_dino(dino_hidden_states)
|
| 1007 |
+
cached_condition["cache"]["dino_hidden_states_proj"] = dino_hidden_states
|
| 1008 |
+
else:
|
| 1009 |
+
dino_hidden_states = None
|
| 1010 |
+
|
| 1011 |
+
if self.use_ra:
|
| 1012 |
+
if "condition_embed_dict" in cached_condition["cache"]:
|
| 1013 |
+
condition_embed_dict = cached_condition["cache"]["condition_embed_dict"]
|
| 1014 |
+
else:
|
| 1015 |
+
condition_embed_dict = {}
|
| 1016 |
+
ref_latents = cached_condition["ref_latents"]
|
| 1017 |
+
N_ref = ref_latents.shape[1]
|
| 1018 |
+
|
| 1019 |
+
if not self.use_dual_stream:
|
| 1020 |
+
ref_latents = [ref_latents]
|
| 1021 |
+
if "embeds_normal" in cached_condition:
|
| 1022 |
+
ref_latents.append(torch.zeros_like(ref_latents[0]))
|
| 1023 |
+
if "embeds_position" in cached_condition:
|
| 1024 |
+
ref_latents.append(torch.zeros_like(ref_latents[0]))
|
| 1025 |
+
ref_latents = torch.cat(ref_latents, dim=2)
|
| 1026 |
+
|
| 1027 |
+
ref_latents = rearrange(ref_latents, "b n c h w -> (b n) c h w")
|
| 1028 |
+
|
| 1029 |
+
encoder_hidden_states_ref = self.unet.learned_text_clip_ref.repeat(B, N_ref, 1, 1)
|
| 1030 |
+
|
| 1031 |
+
encoder_hidden_states_ref = rearrange(encoder_hidden_states_ref, "b n l c -> (b n) l c")
|
| 1032 |
+
|
| 1033 |
+
if added_cond_kwargs is not None:
|
| 1034 |
+
text_embeds_ref = added_cond_kwargs["text_embeds"].unsqueeze(1).repeat(1, N_ref, 1)
|
| 1035 |
+
text_embeds_ref = rearrange(text_embeds_ref, "b n c -> (b n) c")
|
| 1036 |
+
time_ids_ref = added_cond_kwargs["time_ids"].unsqueeze(1).repeat(1, N_ref, 1)
|
| 1037 |
+
time_ids_ref = rearrange(time_ids_ref, "b n c -> (b n) c")
|
| 1038 |
+
added_cond_kwargs_ref = {
|
| 1039 |
+
"text_embeds": text_embeds_ref,
|
| 1040 |
+
"time_ids": time_ids_ref,
|
| 1041 |
+
}
|
| 1042 |
+
else:
|
| 1043 |
+
added_cond_kwargs_ref = None
|
| 1044 |
+
|
| 1045 |
+
noisy_ref_latents = ref_latents
|
| 1046 |
+
timestep_ref = 0
|
| 1047 |
+
if self.use_dual_stream:
|
| 1048 |
+
unet_ref = self.unet_dual
|
| 1049 |
+
else:
|
| 1050 |
+
unet_ref = self.unet
|
| 1051 |
+
unet_ref(
|
| 1052 |
+
noisy_ref_latents,
|
| 1053 |
+
timestep_ref,
|
| 1054 |
+
encoder_hidden_states=encoder_hidden_states_ref,
|
| 1055 |
+
class_labels=None,
|
| 1056 |
+
added_cond_kwargs=added_cond_kwargs_ref,
|
| 1057 |
+
# **kwargs
|
| 1058 |
+
return_dict=False,
|
| 1059 |
+
cross_attention_kwargs={
|
| 1060 |
+
"mode": "w",
|
| 1061 |
+
"num_in_batch": N_ref,
|
| 1062 |
+
"condition_embed_dict": condition_embed_dict,
|
| 1063 |
+
},
|
| 1064 |
+
)
|
| 1065 |
+
cached_condition["cache"]["condition_embed_dict"] = condition_embed_dict
|
| 1066 |
+
else:
|
| 1067 |
+
condition_embed_dict = None
|
| 1068 |
+
|
| 1069 |
+
mva_scale = cached_condition.get("mva_scale", 1.0)
|
| 1070 |
+
ref_scale = cached_condition.get("ref_scale", 1.0)
|
| 1071 |
+
|
| 1072 |
+
return self.unet(
|
| 1073 |
+
sample,
|
| 1074 |
+
timestep,
|
| 1075 |
+
encoder_hidden_states_gen,
|
| 1076 |
+
*args,
|
| 1077 |
+
class_labels=None,
|
| 1078 |
+
added_cond_kwargs=added_cond_kwargs_gen,
|
| 1079 |
+
down_intrablock_additional_residuals=(
|
| 1080 |
+
[sample.to(dtype=self.unet.dtype) for sample in down_intrablock_additional_residuals]
|
| 1081 |
+
if down_intrablock_additional_residuals is not None
|
| 1082 |
+
else None
|
| 1083 |
+
),
|
| 1084 |
+
down_block_additional_residuals=(
|
| 1085 |
+
[sample.to(dtype=self.unet.dtype) for sample in down_block_res_samples]
|
| 1086 |
+
if down_block_res_samples is not None
|
| 1087 |
+
else None
|
| 1088 |
+
),
|
| 1089 |
+
mid_block_additional_residual=(
|
| 1090 |
+
mid_block_res_sample.to(dtype=self.unet.dtype) if mid_block_res_sample is not None else None
|
| 1091 |
+
),
|
| 1092 |
+
return_dict=False,
|
| 1093 |
+
cross_attention_kwargs={
|
| 1094 |
+
"mode": "r",
|
| 1095 |
+
"num_in_batch": N_gen,
|
| 1096 |
+
"dino_hidden_states": dino_hidden_states,
|
| 1097 |
+
"condition_embed_dict": condition_embed_dict,
|
| 1098 |
+
"mva_scale": mva_scale,
|
| 1099 |
+
"ref_scale": ref_scale,
|
| 1100 |
+
"position_voxel_indices": position_voxel_indices,
|
| 1101 |
+
},
|
| 1102 |
+
)
|
hunyuan3d-paintpbr-v2-1/vae/config.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.10.0.dev0",
|
| 4 |
+
"act_fn": "silu",
|
| 5 |
+
"block_out_channels": [
|
| 6 |
+
128,
|
| 7 |
+
256,
|
| 8 |
+
512,
|
| 9 |
+
512
|
| 10 |
+
],
|
| 11 |
+
"down_block_types": [
|
| 12 |
+
"DownEncoderBlock2D",
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D"
|
| 16 |
+
],
|
| 17 |
+
"in_channels": 3,
|
| 18 |
+
"latent_channels": 4,
|
| 19 |
+
"layers_per_block": 2,
|
| 20 |
+
"norm_num_groups": 32,
|
| 21 |
+
"out_channels": 3,
|
| 22 |
+
"sample_size": 768,
|
| 23 |
+
"up_block_types": [
|
| 24 |
+
"UpDecoderBlock2D",
|
| 25 |
+
"UpDecoderBlock2D",
|
| 26 |
+
"UpDecoderBlock2D",
|
| 27 |
+
"UpDecoderBlock2D"
|
| 28 |
+
]
|
| 29 |
+
}
|
hunyuan3d-paintpbr-v2-1/vae/diffusion_pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1b4889b6b1d4ce7ae320a02dedaeff1780ad77d415ea0d744b476155c6377ddc
|
| 3 |
+
size 334707217
|
hunyuan3d-vae-v2-1/config.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
target: hy3dshape.models.ShapeVAE
|
| 2 |
+
params:
|
| 3 |
+
num_latents: 4096
|
| 4 |
+
embed_dim: 64
|
| 5 |
+
num_freqs: 8
|
| 6 |
+
include_pi: false
|
| 7 |
+
heads: 16
|
| 8 |
+
width: 1024
|
| 9 |
+
num_encoder_layers: 8
|
| 10 |
+
num_decoder_layers: 16
|
| 11 |
+
qkv_bias: false
|
| 12 |
+
qk_norm: true
|
| 13 |
+
scale_factor: 1.0039506158752403
|
| 14 |
+
geo_decoder_mlp_expand_ratio: 4
|
| 15 |
+
geo_decoder_downsample_ratio: 1
|
| 16 |
+
geo_decoder_ln_post: true
|
| 17 |
+
point_feats: 4
|
| 18 |
+
pc_size: 81920
|
| 19 |
+
pc_sharpedge_size: 0
|
hunyuan3d-vae-v2-1/model.fp16.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5cbe97f25e6e7abd4bccc80ab07524ec0c86d24118486a9ba49bb5dfb070288a
|
| 3 |
+
size 655648152
|
hy3dpaint/textureGenPipeline.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 2 |
+
# except for the third-party components listed below.
|
| 3 |
+
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 4 |
+
# in the repsective licenses of these third-party components.
|
| 5 |
+
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 6 |
+
# components and must ensure that the usage of the third party components adheres to
|
| 7 |
+
# all relevant laws and regulations.
|
| 8 |
+
|
| 9 |
+
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 10 |
+
# their software and algorithms, including trained model weights, parameters (including
|
| 11 |
+
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 12 |
+
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 13 |
+
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import torch
|
| 17 |
+
import copy
|
| 18 |
+
import trimesh
|
| 19 |
+
import numpy as np
|
| 20 |
+
from PIL import Image
|
| 21 |
+
from typing import List
|
| 22 |
+
from DifferentiableRenderer.MeshRender import MeshRender
|
| 23 |
+
from utils.simplify_mesh_utils import remesh_mesh
|
| 24 |
+
from utils.multiview_utils import multiviewDiffusionNet
|
| 25 |
+
from utils.pipeline_utils import ViewProcessor
|
| 26 |
+
from utils.image_super_utils import imageSuperNet
|
| 27 |
+
from utils.uvwrap_utils import mesh_uv_wrap
|
| 28 |
+
from DifferentiableRenderer.mesh_utils import convert_obj_to_glb
|
| 29 |
+
import warnings
|
| 30 |
+
|
| 31 |
+
warnings.filterwarnings("ignore")
|
| 32 |
+
from diffusers.utils import logging as diffusers_logging
|
| 33 |
+
|
| 34 |
+
diffusers_logging.set_verbosity(50)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Hunyuan3DPaintConfig:
|
| 38 |
+
def __init__(self, max_num_view, resolution):
|
| 39 |
+
self.device = "cuda"
|
| 40 |
+
|
| 41 |
+
self.multiview_cfg_path = "cfgs/hunyuan-paint-pbr.yaml"
|
| 42 |
+
self.custom_pipeline = "hunyuanpaintpbr"
|
| 43 |
+
self.multiview_pretrained_path = "tencent/Hunyuan3D-2.1"
|
| 44 |
+
self.dino_ckpt_path = "facebook/dinov2-giant"
|
| 45 |
+
self.realesrgan_ckpt_path = "ckpt/RealESRGAN_x4plus.pth"
|
| 46 |
+
|
| 47 |
+
self.raster_mode = "cr"
|
| 48 |
+
self.bake_mode = "back_sample"
|
| 49 |
+
self.render_size = 1024 * 2
|
| 50 |
+
self.texture_size = 1024 * 4
|
| 51 |
+
self.max_selected_view_num = max_num_view
|
| 52 |
+
self.resolution = resolution
|
| 53 |
+
self.bake_exp = 4
|
| 54 |
+
self.merge_method = "fast"
|
| 55 |
+
|
| 56 |
+
# view selection
|
| 57 |
+
self.candidate_camera_azims = [0, 90, 180, 270, 0, 180]
|
| 58 |
+
self.candidate_camera_elevs = [0, 0, 0, 0, 90, -90]
|
| 59 |
+
self.candidate_view_weights = [1, 0.1, 0.5, 0.1, 0.05, 0.05]
|
| 60 |
+
|
| 61 |
+
for azim in range(0, 360, 30):
|
| 62 |
+
self.candidate_camera_azims.append(azim)
|
| 63 |
+
self.candidate_camera_elevs.append(20)
|
| 64 |
+
self.candidate_view_weights.append(0.01)
|
| 65 |
+
|
| 66 |
+
self.candidate_camera_azims.append(azim)
|
| 67 |
+
self.candidate_camera_elevs.append(-20)
|
| 68 |
+
self.candidate_view_weights.append(0.01)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class Hunyuan3DPaintPipeline:
|
| 72 |
+
|
| 73 |
+
def __init__(self, config=None) -> None:
|
| 74 |
+
self.config = config if config is not None else Hunyuan3DPaintConfig()
|
| 75 |
+
self.models = {}
|
| 76 |
+
self.stats_logs = {}
|
| 77 |
+
self.render = MeshRender(
|
| 78 |
+
default_resolution=self.config.render_size,
|
| 79 |
+
texture_size=self.config.texture_size,
|
| 80 |
+
bake_mode=self.config.bake_mode,
|
| 81 |
+
raster_mode=self.config.raster_mode,
|
| 82 |
+
)
|
| 83 |
+
self.view_processor = ViewProcessor(self.config, self.render)
|
| 84 |
+
self.load_models()
|
| 85 |
+
|
| 86 |
+
def load_models(self):
|
| 87 |
+
torch.cuda.empty_cache()
|
| 88 |
+
self.models["super_model"] = imageSuperNet(self.config)
|
| 89 |
+
self.models["multiview_model"] = multiviewDiffusionNet(self.config)
|
| 90 |
+
print("Models Loaded.")
|
| 91 |
+
|
| 92 |
+
@torch.no_grad()
|
| 93 |
+
def __call__(self, mesh_path=None, image_path=None, output_mesh_path=None, use_remesh=True, save_glb=True):
|
| 94 |
+
"""Generate texture for 3D mesh using multiview diffusion"""
|
| 95 |
+
# Ensure image_prompt is a list
|
| 96 |
+
if isinstance(image_path, str):
|
| 97 |
+
image_prompt = Image.open(image_path)
|
| 98 |
+
elif isinstance(image_path, Image.Image):
|
| 99 |
+
image_prompt = image_path
|
| 100 |
+
if not isinstance(image_prompt, List):
|
| 101 |
+
image_prompt = [image_prompt]
|
| 102 |
+
else:
|
| 103 |
+
image_prompt = image_path
|
| 104 |
+
|
| 105 |
+
# Process mesh
|
| 106 |
+
path = os.path.dirname(mesh_path)
|
| 107 |
+
if use_remesh:
|
| 108 |
+
processed_mesh_path = os.path.join(path, "white_mesh_remesh.obj")
|
| 109 |
+
remesh_mesh(mesh_path, processed_mesh_path)
|
| 110 |
+
else:
|
| 111 |
+
processed_mesh_path = mesh_path
|
| 112 |
+
|
| 113 |
+
# Output path
|
| 114 |
+
if output_mesh_path is None:
|
| 115 |
+
output_mesh_path = os.path.join(path, f"textured_mesh.obj")
|
| 116 |
+
|
| 117 |
+
# Load mesh
|
| 118 |
+
mesh = trimesh.load(processed_mesh_path)
|
| 119 |
+
mesh = mesh_uv_wrap(mesh)
|
| 120 |
+
self.render.load_mesh(mesh=mesh)
|
| 121 |
+
|
| 122 |
+
########### View Selection #########
|
| 123 |
+
selected_camera_elevs, selected_camera_azims, selected_view_weights = self.view_processor.bake_view_selection(
|
| 124 |
+
self.config.candidate_camera_elevs,
|
| 125 |
+
self.config.candidate_camera_azims,
|
| 126 |
+
self.config.candidate_view_weights,
|
| 127 |
+
self.config.max_selected_view_num,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
normal_maps = self.view_processor.render_normal_multiview(
|
| 131 |
+
selected_camera_elevs, selected_camera_azims, use_abs_coor=True
|
| 132 |
+
)
|
| 133 |
+
position_maps = self.view_processor.render_position_multiview(selected_camera_elevs, selected_camera_azims)
|
| 134 |
+
|
| 135 |
+
########## Style ###########
|
| 136 |
+
image_caption = "high quality"
|
| 137 |
+
image_style = []
|
| 138 |
+
for image in image_prompt:
|
| 139 |
+
image = image.resize((512, 512))
|
| 140 |
+
if image.mode == "RGBA":
|
| 141 |
+
white_bg = Image.new("RGB", image.size, (255, 255, 255))
|
| 142 |
+
white_bg.paste(image, mask=image.getchannel("A"))
|
| 143 |
+
image = white_bg
|
| 144 |
+
image_style.append(image)
|
| 145 |
+
image_style = [image.convert("RGB") for image in image_style]
|
| 146 |
+
|
| 147 |
+
########### Multiview ##########
|
| 148 |
+
multiviews_pbr = self.models["multiview_model"](
|
| 149 |
+
image_style,
|
| 150 |
+
normal_maps + position_maps,
|
| 151 |
+
prompt=image_caption,
|
| 152 |
+
custom_view_size=self.config.resolution,
|
| 153 |
+
resize_input=True,
|
| 154 |
+
)
|
| 155 |
+
########### Enhance ##########
|
| 156 |
+
enhance_images = {}
|
| 157 |
+
enhance_images["albedo"] = copy.deepcopy(multiviews_pbr["albedo"])
|
| 158 |
+
enhance_images["mr"] = copy.deepcopy(multiviews_pbr["mr"])
|
| 159 |
+
|
| 160 |
+
for i in range(len(enhance_images["albedo"])):
|
| 161 |
+
enhance_images["albedo"][i] = self.models["super_model"](enhance_images["albedo"][i])
|
| 162 |
+
enhance_images["mr"][i] = self.models["super_model"](enhance_images["mr"][i])
|
| 163 |
+
|
| 164 |
+
########### Bake ##########
|
| 165 |
+
for i in range(len(enhance_images)):
|
| 166 |
+
enhance_images["albedo"][i] = enhance_images["albedo"][i].resize(
|
| 167 |
+
(self.config.render_size, self.config.render_size)
|
| 168 |
+
)
|
| 169 |
+
enhance_images["mr"][i] = enhance_images["mr"][i].resize((self.config.render_size, self.config.render_size))
|
| 170 |
+
texture, mask = self.view_processor.bake_from_multiview(
|
| 171 |
+
enhance_images["albedo"], selected_camera_elevs, selected_camera_azims, selected_view_weights
|
| 172 |
+
)
|
| 173 |
+
mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
|
| 174 |
+
texture_mr, mask_mr = self.view_processor.bake_from_multiview(
|
| 175 |
+
enhance_images["mr"], selected_camera_elevs, selected_camera_azims, selected_view_weights
|
| 176 |
+
)
|
| 177 |
+
mask_mr_np = (mask_mr.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
|
| 178 |
+
|
| 179 |
+
########## inpaint ###########
|
| 180 |
+
texture = self.view_processor.texture_inpaint(texture, mask_np)
|
| 181 |
+
self.render.set_texture(texture, force_set=True)
|
| 182 |
+
if "mr" in enhance_images:
|
| 183 |
+
texture_mr = self.view_processor.texture_inpaint(texture_mr, mask_mr_np)
|
| 184 |
+
self.render.set_texture_mr(texture_mr)
|
| 185 |
+
|
| 186 |
+
self.render.save_mesh(output_mesh_path, downsample=True)
|
| 187 |
+
|
| 188 |
+
if save_glb:
|
| 189 |
+
convert_obj_to_glb(output_mesh_path, output_mesh_path.replace(".obj", ".glb"))
|
| 190 |
+
output_glb_path = output_mesh_path.replace(".obj", ".glb")
|
| 191 |
+
|
| 192 |
+
return output_mesh_path
|
hy3dpaint/utils/multiview_utils.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 2 |
+
# except for the third-party components listed below.
|
| 3 |
+
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 4 |
+
# in the repsective licenses of these third-party components.
|
| 5 |
+
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 6 |
+
# components and must ensure that the usage of the third party components adheres to
|
| 7 |
+
# all relevant laws and regulations.
|
| 8 |
+
|
| 9 |
+
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 10 |
+
# their software and algorithms, including trained model weights, parameters (including
|
| 11 |
+
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 12 |
+
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 13 |
+
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import torch
|
| 17 |
+
import random
|
| 18 |
+
import numpy as np
|
| 19 |
+
from PIL import Image
|
| 20 |
+
from typing import List
|
| 21 |
+
import huggingface_hub
|
| 22 |
+
from omegaconf import OmegaConf
|
| 23 |
+
from diffusers import DiffusionPipeline
|
| 24 |
+
from diffusers import EulerAncestralDiscreteScheduler, DDIMScheduler, UniPCMultistepScheduler
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class multiviewDiffusionNet:
|
| 28 |
+
def __init__(self, config) -> None:
|
| 29 |
+
self.device = config.device
|
| 30 |
+
|
| 31 |
+
cfg_path = config.multiview_cfg_path
|
| 32 |
+
custom_pipeline = config.custom_pipeline
|
| 33 |
+
cfg = OmegaConf.load(cfg_path)
|
| 34 |
+
self.cfg = cfg
|
| 35 |
+
self.mode = self.cfg.model.params.stable_diffusion_config.custom_pipeline[2:]
|
| 36 |
+
|
| 37 |
+
model_path = huggingface_hub.snapshot_download(
|
| 38 |
+
repo_id=config.multiview_pretrained_path,
|
| 39 |
+
allow_patterns=["hunyuan3d-paintpbr-v2-1/*"],
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
model_path = os.path.join(model_path, "hunyuan3d-paintpbr-v2-1")
|
| 43 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
| 44 |
+
model_path,
|
| 45 |
+
custom_pipeline=custom_pipeline,
|
| 46 |
+
torch_dtype=torch.float16
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, timestep_spacing="trailing")
|
| 50 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 51 |
+
pipeline.eval()
|
| 52 |
+
setattr(pipeline, "view_size", cfg.model.params.get("view_size", 320))
|
| 53 |
+
self.pipeline = pipeline.to(self.device)
|
| 54 |
+
|
| 55 |
+
if hasattr(self.pipeline.unet, "use_dino") and self.pipeline.unet.use_dino:
|
| 56 |
+
from hunyuanpaintpbr.unet.modules import Dino_v2
|
| 57 |
+
self.dino_v2 = Dino_v2(config.dino_ckpt_path).to(torch.float16)
|
| 58 |
+
self.dino_v2 = self.dino_v2.to(self.device)
|
| 59 |
+
|
| 60 |
+
def seed_everything(self, seed):
|
| 61 |
+
random.seed(seed)
|
| 62 |
+
np.random.seed(seed)
|
| 63 |
+
torch.manual_seed(seed)
|
| 64 |
+
os.environ["PL_GLOBAL_SEED"] = str(seed)
|
| 65 |
+
|
| 66 |
+
@torch.no_grad()
|
| 67 |
+
def __call__(self, images, conditions, prompt=None, custom_view_size=None, resize_input=False):
|
| 68 |
+
pils = self.forward_one(
|
| 69 |
+
images, conditions, prompt=prompt, custom_view_size=custom_view_size, resize_input=resize_input
|
| 70 |
+
)
|
| 71 |
+
return pils
|
| 72 |
+
|
| 73 |
+
def forward_one(self, input_images, control_images, prompt=None, custom_view_size=None, resize_input=False):
|
| 74 |
+
self.seed_everything(0)
|
| 75 |
+
custom_view_size = custom_view_size if custom_view_size is not None else self.pipeline.view_size
|
| 76 |
+
if not isinstance(input_images, List):
|
| 77 |
+
input_images = [input_images]
|
| 78 |
+
if not resize_input:
|
| 79 |
+
input_images = [
|
| 80 |
+
input_image.resize((self.pipeline.view_size, self.pipeline.view_size)) for input_image in input_images
|
| 81 |
+
]
|
| 82 |
+
else:
|
| 83 |
+
input_images = [input_image.resize((custom_view_size, custom_view_size)) for input_image in input_images]
|
| 84 |
+
for i in range(len(control_images)):
|
| 85 |
+
control_images[i] = control_images[i].resize((custom_view_size, custom_view_size))
|
| 86 |
+
if control_images[i].mode == "L":
|
| 87 |
+
control_images[i] = control_images[i].point(lambda x: 255 if x > 1 else 0, mode="1")
|
| 88 |
+
kwargs = dict(generator=torch.Generator(device=self.pipeline.device).manual_seed(0))
|
| 89 |
+
|
| 90 |
+
num_view = len(control_images) // 2
|
| 91 |
+
normal_image = [[control_images[i] for i in range(num_view)]]
|
| 92 |
+
position_image = [[control_images[i + num_view] for i in range(num_view)]]
|
| 93 |
+
|
| 94 |
+
kwargs["width"] = custom_view_size
|
| 95 |
+
kwargs["height"] = custom_view_size
|
| 96 |
+
kwargs["num_in_batch"] = num_view
|
| 97 |
+
kwargs["images_normal"] = normal_image
|
| 98 |
+
kwargs["images_position"] = position_image
|
| 99 |
+
|
| 100 |
+
if hasattr(self.pipeline.unet, "use_dino") and self.pipeline.unet.use_dino:
|
| 101 |
+
dino_hidden_states = self.dino_v2(input_images[0])
|
| 102 |
+
kwargs["dino_hidden_states"] = dino_hidden_states
|
| 103 |
+
|
| 104 |
+
sync_condition = None
|
| 105 |
+
|
| 106 |
+
infer_steps_dict = {
|
| 107 |
+
"EulerAncestralDiscreteScheduler": 30,
|
| 108 |
+
"UniPCMultistepScheduler": 15,
|
| 109 |
+
"DDIMScheduler": 50,
|
| 110 |
+
"ShiftSNRScheduler": 15,
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
mvd_image = self.pipeline(
|
| 114 |
+
input_images[0:1],
|
| 115 |
+
num_inference_steps=infer_steps_dict[self.pipeline.scheduler.__class__.__name__],
|
| 116 |
+
prompt=prompt,
|
| 117 |
+
sync_condition=sync_condition,
|
| 118 |
+
guidance_scale=3.0,
|
| 119 |
+
**kwargs,
|
| 120 |
+
).images
|
| 121 |
+
|
| 122 |
+
if "pbr" in self.mode:
|
| 123 |
+
mvd_image = {"albedo": mvd_image[:num_view], "mr": mvd_image[num_view:]}
|
| 124 |
+
# mvd_image = {'albedo':mvd_image[:num_view]}
|
| 125 |
+
else:
|
| 126 |
+
mvd_image = {"hdr": mvd_image}
|
| 127 |
+
|
| 128 |
+
return mvd_image
|