fffiloni committed
Commit 459fa69 · verified · 1 Parent(s): 6ed5b2d

Migrated from GitHub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ docs/teaser.gif filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,352 @@
1
+ Creative Commons Attribution-NonCommercial 4.0 International
2
+
3
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
4
+ does not provide legal services or legal advice. Distribution of
5
+ Creative Commons public licenses does not create a lawyer-client or
6
+ other relationship. Creative Commons makes its licenses and related
7
+ information available on an "as-is" basis. Creative Commons gives no
8
+ warranties regarding its licenses, any material licensed under their
9
+ terms and conditions, or any related information. Creative Commons
10
+ disclaims all liability for damages resulting from their use to the
11
+ fullest extent possible.
12
+
13
+ Using Creative Commons Public Licenses
14
+
15
+ Creative Commons public licenses provide a standard set of terms and
16
+ conditions that creators and other rights holders may use to share
17
+ original works of authorship and other material subject to copyright and
18
+ certain other rights specified in the public license below. The
19
+ following considerations are for informational purposes only, are not
20
+ exhaustive, and do not form part of our licenses.
21
+
22
+ - Considerations for licensors: Our public licenses are intended for
23
+ use by those authorized to give the public permission to use
24
+ material in ways otherwise restricted by copyright and certain other
25
+ rights. Our licenses are irrevocable. Licensors should read and
26
+ understand the terms and conditions of the license they choose
27
+ before applying it. Licensors should also secure all rights
28
+ necessary before applying our licenses so that the public can reuse
29
+ the material as expected. Licensors should clearly mark any material
30
+ not subject to the license. This includes other CC-licensed
31
+ material, or material used under an exception or limitation to
32
+ copyright. More considerations for licensors :
33
+ wiki.creativecommons.org/Considerations_for_licensors
34
+
35
+ - Considerations for the public: By using one of our public licenses,
36
+ a licensor grants the public permission to use the licensed material
37
+ under specified terms and conditions. If the licensor's permission
38
+ is not necessary for any reason–for example, because of any
39
+ applicable exception or limitation to copyright–then that use is not
40
+ regulated by the license. Our licenses grant only permissions under
41
+ copyright and certain other rights that a licensor has authority to
42
+ grant. Use of the licensed material may still be restricted for
43
+ other reasons, including because others have copyright or other
44
+ rights in the material. A licensor may make special requests, such
45
+ as asking that all changes be marked or described. Although not
46
+ required by our licenses, you are encouraged to respect those
47
+ requests where reasonable. More considerations for the public :
48
+ wiki.creativecommons.org/Considerations_for_licensees
49
+
50
+ Creative Commons Attribution-NonCommercial 4.0 International Public
51
+ License
52
+
53
+ By exercising the Licensed Rights (defined below), You accept and agree
54
+ to be bound by the terms and conditions of this Creative Commons
55
+ Attribution-NonCommercial 4.0 International Public License ("Public
56
+ License"). To the extent this Public License may be interpreted as a
57
+ contract, You are granted the Licensed Rights in consideration of Your
58
+ acceptance of these terms and conditions, and the Licensor grants You
59
+ such rights in consideration of benefits the Licensor receives from
60
+ making the Licensed Material available under these terms and conditions.
61
+
62
+ - Section 1 – Definitions.
63
+
64
+ - a. Adapted Material means material subject to Copyright and
65
+ Similar Rights that is derived from or based upon the Licensed
66
+ Material and in which the Licensed Material is translated,
67
+ altered, arranged, transformed, or otherwise modified in a
68
+ manner requiring permission under the Copyright and Similar
69
+ Rights held by the Licensor. For purposes of this Public
70
+ License, where the Licensed Material is a musical work,
71
+ performance, or sound recording, Adapted Material is always
72
+ produced where the Licensed Material is synched in timed
73
+ relation with a moving image.
74
+ - b. Adapter's License means the license You apply to Your
75
+ Copyright and Similar Rights in Your contributions to Adapted
76
+ Material in accordance with the terms and conditions of this
77
+ Public License.
78
+ - c. Copyright and Similar Rights means copyright and/or similar
79
+ rights closely related to copyright including, without
80
+ limitation, performance, broadcast, sound recording, and Sui
81
+ Generis Database Rights, without regard to how the rights are
82
+ labeled or categorized. For purposes of this Public License, the
83
+ rights specified in Section 2(b)(1)-(2) are not Copyright and
84
+ Similar Rights.
85
+ - d. Effective Technological Measures means those measures that,
86
+ in the absence of proper authority, may not be circumvented
87
+ under laws fulfilling obligations under Article 11 of the WIPO
88
+ Copyright Treaty adopted on December 20, 1996, and/or similar
89
+ international agreements.
90
+ - e. Exceptions and Limitations means fair use, fair dealing,
91
+ and/or any other exception or limitation to Copyright and
92
+ Similar Rights that applies to Your use of the Licensed
93
+ Material.
94
+ - f. Licensed Material means the artistic or literary work,
95
+ database, or other material to which the Licensor applied this
96
+ Public License.
97
+ - g. Licensed Rights means the rights granted to You subject to
98
+ the terms and conditions of this Public License, which are
99
+ limited to all Copyright and Similar Rights that apply to Your
100
+ use of the Licensed Material and that the Licensor has authority
101
+ to license.
102
+ - h. Licensor means the individual(s) or entity(ies) granting
103
+ rights under this Public License.
104
+ - i. NonCommercial means not primarily intended for or directed
105
+ towards commercial advantage or monetary compensation. For
106
+ purposes of this Public License, the exchange of the Licensed
107
+ Material for other material subject to Copyright and Similar
108
+ Rights by digital file-sharing or similar means is NonCommercial
109
+ provided there is no payment of monetary compensation in
110
+ connection with the exchange.
111
+ - j. Share means to provide material to the public by any means or
112
+ process that requires permission under the Licensed Rights, such
113
+ as reproduction, public display, public performance,
114
+ distribution, dissemination, communication, or importation, and
115
+ to make material available to the public including in ways that
116
+ members of the public may access the material from a place and
117
+ at a time individually chosen by them.
118
+ - k. Sui Generis Database Rights means rights other than copyright
119
+ resulting from Directive 96/9/EC of the European Parliament and
120
+ of the Council of 11 March 1996 on the legal protection of
121
+ databases, as amended and/or succeeded, as well as other
122
+ essentially equivalent rights anywhere in the world.
123
+ - l. You means the individual or entity exercising the Licensed
124
+ Rights under this Public License. Your has a corresponding
125
+ meaning.
126
+
127
+ - Section 2 – Scope.
128
+
129
+ - a. License grant.
130
+ - 1. Subject to the terms and conditions of this Public
131
+ License, the Licensor hereby grants You a worldwide,
132
+ royalty-free, non-sublicensable, non-exclusive, irrevocable
133
+ license to exercise the Licensed Rights in the Licensed
134
+ Material to:
135
+ - A. reproduce and Share the Licensed Material, in whole
136
+ or in part, for NonCommercial purposes only; and
137
+ - B. produce, reproduce, and Share Adapted Material for
138
+ NonCommercial purposes only.
139
+ - 2. Exceptions and Limitations. For the avoidance of doubt,
140
+ where Exceptions and Limitations apply to Your use, this
141
+ Public License does not apply, and You do not need to comply
142
+ with its terms and conditions.
143
+ - 3. Term. The term of this Public License is specified in
144
+ Section 6(a).
145
+ - 4. Media and formats; technical modifications allowed. The
146
+ Licensor authorizes You to exercise the Licensed Rights in
147
+ all media and formats whether now known or hereafter
148
+ created, and to make technical modifications necessary to do
149
+ so. The Licensor waives and/or agrees not to assert any
150
+ right or authority to forbid You from making technical
151
+ modifications necessary to exercise the Licensed Rights,
152
+ including technical modifications necessary to circumvent
153
+ Effective Technological Measures. For purposes of this
154
+ Public License, simply making modifications authorized by
155
+ this Section 2(a)(4) never produces Adapted Material.
156
+ - 5. Downstream recipients.
157
+ - A. Offer from the Licensor – Licensed Material. Every
158
+ recipient of the Licensed Material automatically
159
+ receives an offer from the Licensor to exercise the
160
+ Licensed Rights under the terms and conditions of this
161
+ Public License.
162
+ - B. No downstream restrictions. You may not offer or
163
+ impose any additional or different terms or conditions
164
+ on, or apply any Effective Technological Measures to,
165
+ the Licensed Material if doing so restricts exercise of
166
+ the Licensed Rights by any recipient of the Licensed
167
+ Material.
168
+ - 6. No endorsement. Nothing in this Public License
169
+ constitutes or may be construed as permission to assert or
170
+ imply that You are, or that Your use of the Licensed
171
+ Material is, connected with, or sponsored, endorsed, or
172
+ granted official status by, the Licensor or others
173
+ designated to receive attribution as provided in Section
174
+ 3(a)(1)(A)(i).
175
+ - b. Other rights.
176
+ - 1. Moral rights, such as the right of integrity, are not
177
+ licensed under this Public License, nor are publicity,
178
+ privacy, and/or other similar personality rights; however,
179
+ to the extent possible, the Licensor waives and/or agrees
180
+ not to assert any such rights held by the Licensor to the
181
+ limited extent necessary to allow You to exercise the
182
+ Licensed Rights, but not otherwise.
183
+ - 2. Patent and trademark rights are not licensed under this
184
+ Public License.
185
+ - 3. To the extent possible, the Licensor waives any right to
186
+ collect royalties from You for the exercise of the Licensed
187
+ Rights, whether directly or through a collecting society
188
+ under any voluntary or waivable statutory or compulsory
189
+ licensing scheme. In all other cases the Licensor expressly
190
+ reserves any right to collect such royalties, including when
191
+ the Licensed Material is used other than for NonCommercial
192
+ purposes.
193
+
194
+ - Section 3 – License Conditions.
195
+
196
+ Your exercise of the Licensed Rights is expressly made subject to
197
+ the following conditions.
198
+
199
+ - a. Attribution.
200
+ - 1. If You Share the Licensed Material (including in modified
201
+ form), You must:
202
+ - A. retain the following if it is supplied by the
203
+ Licensor with the Licensed Material:
204
+ - i. identification of the creator(s) of the Licensed
205
+ Material and any others designated to receive
206
+ attribution, in any reasonable manner requested by
207
+ the Licensor (including by pseudonym if designated);
208
+ - ii. a copyright notice;
209
+ - iii. a notice that refers to this Public License;
210
+ - iv. a notice that refers to the disclaimer of
211
+ warranties;
212
+ - v. a URI or hyperlink to the Licensed Material to
213
+ the extent reasonably practicable;
214
+ - B. indicate if You modified the Licensed Material and
215
+ retain an indication of any previous modifications; and
216
+ - C. indicate the Licensed Material is licensed under this
217
+ Public License, and include the text of, or the URI or
218
+ hyperlink to, this Public License.
219
+ - 2. You may satisfy the conditions in Section 3(a)(1) in any
220
+ reasonable manner based on the medium, means, and context in
221
+ which You Share the Licensed Material. For example, it may
222
+ be reasonable to satisfy the conditions by providing a URI
223
+ or hyperlink to a resource that includes the required
224
+ information.
225
+ - 3. If requested by the Licensor, You must remove any of the
226
+ information required by Section 3(a)(1)(A) to the extent
227
+ reasonably practicable.
228
+ - 4. If You Share Adapted Material You produce, the Adapter's
229
+ License You apply must not prevent recipients of the Adapted
230
+ Material from complying with this Public License.
231
+
232
+ - Section 4 – Sui Generis Database Rights.
233
+
234
+ Where the Licensed Rights include Sui Generis Database Rights that
235
+ apply to Your use of the Licensed Material:
236
+
237
+ - a. for the avoidance of doubt, Section 2(a)(1) grants You the
238
+ right to extract, reuse, reproduce, and Share all or a
239
+ substantial portion of the contents of the database for
240
+ NonCommercial purposes only;
241
+ - b. if You include all or a substantial portion of the database
242
+ contents in a database in which You have Sui Generis Database
243
+ Rights, then the database in which You have Sui Generis Database
244
+ Rights (but not its individual contents) is Adapted Material;
245
+ and
246
+ - c. You must comply with the conditions in Section 3(a) if You
247
+ Share all or a substantial portion of the contents of the
248
+ database.
249
+
250
+ For the avoidance of doubt, this Section 4 supplements and does not
251
+ replace Your obligations under this Public License where the
252
+ Licensed Rights include other Copyright and Similar Rights.
253
+
254
+ - Section 5 – Disclaimer of Warranties and Limitation of Liability.
255
+
256
+ - a. Unless otherwise separately undertaken by the Licensor, to
257
+ the extent possible, the Licensor offers the Licensed Material
258
+ as-is and as-available, and makes no representations or
259
+ warranties of any kind concerning the Licensed Material, whether
260
+ express, implied, statutory, or other. This includes, without
261
+ limitation, warranties of title, merchantability, fitness for a
262
+ particular purpose, non-infringement, absence of latent or other
263
+ defects, accuracy, or the presence or absence of errors, whether
264
+ or not known or discoverable. Where disclaimers of warranties
265
+ are not allowed in full or in part, this disclaimer may not
266
+ apply to You.
267
+ - b. To the extent possible, in no event will the Licensor be
268
+ liable to You on any legal theory (including, without
269
+ limitation, negligence) or otherwise for any direct, special,
270
+ indirect, incidental, consequential, punitive, exemplary, or
271
+ other losses, costs, expenses, or damages arising out of this
272
+ Public License or use of the Licensed Material, even if the
273
+ Licensor has been advised of the possibility of such losses,
274
+ costs, expenses, or damages. Where a limitation of liability is
275
+ not allowed in full or in part, this limitation may not apply to
276
+ You.
277
+ - c. The disclaimer of warranties and limitation of liability
278
+ provided above shall be interpreted in a manner that, to the
279
+ extent possible, most closely approximates an absolute
280
+ disclaimer and waiver of all liability.
281
+
282
+ - Section 6 – Term and Termination.
283
+
284
+ - a. This Public License applies for the term of the Copyright and
285
+ Similar Rights licensed here. However, if You fail to comply
286
+ with this Public License, then Your rights under this Public
287
+ License terminate automatically.
288
+ - b. Where Your right to use the Licensed Material has terminated
289
+ under Section 6(a), it reinstates:
290
+
291
+ - 1. automatically as of the date the violation is cured,
292
+ provided it is cured within 30 days of Your discovery of the
293
+ violation; or
294
+ - 2. upon express reinstatement by the Licensor.
295
+
296
+ For the avoidance of doubt, this Section 6(b) does not affect
297
+ any right the Licensor may have to seek remedies for Your
298
+ violations of this Public License.
299
+
300
+ - c. For the avoidance of doubt, the Licensor may also offer the
301
+ Licensed Material under separate terms or conditions or stop
302
+ distributing the Licensed Material at any time; however, doing
303
+ so will not terminate this Public License.
304
+ - d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
305
+ License.
306
+
307
+ - Section 7 – Other Terms and Conditions.
308
+
309
+ - a. The Licensor shall not be bound by any additional or
310
+ different terms or conditions communicated by You unless
311
+ expressly agreed.
312
+ - b. Any arrangements, understandings, or agreements regarding the
313
+ Licensed Material not stated herein are separate from and
314
+ independent of the terms and conditions of this Public License.
315
+
316
+ - Section 8 – Interpretation.
317
+
318
+ - a. For the avoidance of doubt, this Public License does not, and
319
+ shall not be interpreted to, reduce, limit, restrict, or impose
320
+ conditions on any use of the Licensed Material that could
321
+ lawfully be made without permission under this Public License.
322
+ - b. To the extent possible, if any provision of this Public
323
+ License is deemed unenforceable, it shall be automatically
324
+ reformed to the minimum extent necessary to make it enforceable.
325
+ If the provision cannot be reformed, it shall be severed from
326
+ this Public License without affecting the enforceability of the
327
+ remaining terms and conditions.
328
+ - c. No term or condition of this Public License will be waived
329
+ and no failure to comply consented to unless expressly agreed to
330
+ by the Licensor.
331
+ - d. Nothing in this Public License constitutes or may be
332
+ interpreted as a limitation upon, or waiver of, any privileges
333
+ and immunities that apply to the Licensor or You, including from
334
+ the legal processes of any jurisdiction or authority.
335
+
336
+ Creative Commons is not a party to its public licenses. Notwithstanding,
337
+ Creative Commons may elect to apply one of its public licenses to
338
+ material it publishes and in those instances will be considered the
339
+ "Licensor." The text of the Creative Commons public licenses is
340
+ dedicated to the public domain under the CC0 Public Domain Dedication.
341
+ Except for the limited purpose of indicating that material is shared
342
+ under a Creative Commons public license or as otherwise permitted by the
343
+ Creative Commons policies published at creativecommons.org/policies,
344
+ Creative Commons does not authorize the use of the trademark "Creative
345
+ Commons" or any other trademark or logo of Creative Commons without its
346
+ prior written consent including, without limitation, in connection with
347
+ any unauthorized modifications to any of its public licenses or any
348
+ other arrangements, understandings, or agreements concerning use of
349
+ licensed material. For the avoidance of doubt, this paragraph does not
350
+ form part of the public licenses.
351
+
352
+ Creative Commons may be contacted at creativecommons.org.
ORIGINAL_README.md ADDED
@@ -0,0 +1,140 @@
1
+ # MangaNinja: Line Art Colorization with Precise Reference Following
2
+
3
+ This repository is the official implementation of the paper "MangaNinja: Line Art Colorization with Precise Reference Following".
4
+
5
+ [![Website](docs/badge-website.svg)](https://johanan528.github.io/MangaNinjia/)
6
+ [![Paper](https://img.shields.io/badge/arXiv-PDF-b31b1b)](https://arxiv.org/abs/2501.08332)
7
+ [![License](https://img.shields.io/badge/License-CC%20BY--NC%204.0-929292)](https://creativecommons.org/licenses/by-nc/4.0/)
8
+
9
+ <p align="center">
10
+ <a href="https://johanan528.github.io/"><strong>Zhiheng Liu*</strong></a>
11
+ ·
12
+ <a href="https://felixcheng97.github.io/"><strong>Ka Leong Cheng*</strong></a>
13
+ ·
14
+ <a href="https://xavierchen34.github.io/"><strong>Xi Chen</strong></a>
15
+ ·
16
+ <a href="https://jiexiaou.github.io/"><strong>Jie Xiao</strong></a>
17
+ ·
18
+ <a href="https://ken-ouyang.github.io/"><strong>Hao Ouyang</strong></a>
19
+ ·
20
+ <a href="https://scholar.google.com/citations?user=Mo_2YsgAAAAJ&hl=zh-CN"><strong>Kai Zhu</strong></a>
21
+ ·
22
+ <a href="https://scholar.google.com/citations?user=8zksQb4AAAAJ&hl=zh-CN"><strong>Yu Liu</strong></a>
23
+ ·
24
+ <a href="https://shenyujun.github.io/"><strong>Yujun Shen</strong></a>
25
+ ·
26
+ <a href="https://cqf.io/"><strong>Qifeng Chen</strong></a>
27
+ ·
28
+ <a href="http://luoping.me/"><strong>Ping Luo</strong></a>
29
+ <br>
30
+ </p>
31
+
32
+ We propose **MangaNinja**, a reference-based line art colorization method. MangaNinja
33
+ automatically aligns the reference with the line art for colorization, demonstrating remarkable consistency. Additionally, users can handle
34
+ more complex cases with point control. We hope that MangaNinja can accelerate the colorization process in the anime industry.
35
+
36
+ ![teaser](docs/teaser.gif)
37
+ ## 📢 News
38
+ * 2025-01-15: Inference code and paper are released.
39
+ * 2025-01-16: MangaNinja is available on Windows with automatic installation and model download (only 6 GB of VRAM needed). Thanks @sdbds! You can find it [here](https://github.com/sdbds/MangaNinjia-for-windows). 🔥
40
+ * 🏃: We will open an issue area to collect user needs and adjust the model accordingly. This includes more memory-efficient architectures, additional line art data formats (such as binarized line art), and possibly retraining MangaNinjia on a stronger foundation model (SD3, Flux).
41
+
42
+ ## 🛠️ Setup
43
+
44
+ ### 📦 Repository
45
+
46
+ Clone the repository (requires git):
47
+
48
+ ```bash
49
+ git clone https://github.com/ali-vilab/MangaNinjia.git
50
+ cd MangaNinjia
51
+ ```
52
+
53
+ ### 💻 Dependencies
54
+
55
+ Install with `conda`:
56
+ ```bash
57
+ conda env create -f environment.yaml
58
+ conda activate MangaNinjia
59
+ ```
60
+ ### ⚙️ Weights
61
+ * You can download the base weights here: [StableDiffusion](https://modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5), [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14), [control_v11p_sd15_lineart](https://huggingface.co/lllyasviel/control_v11p_sd15_lineart), and [Annotators](https://huggingface.co/lllyasviel/Annotators/blob/main/sk_model.pth).
62
+ * You can download our [MangaNinjia model](https://huggingface.co/Johanan0528/MangaNinjia) from Hugging Face (a programmatic download sketch follows the layout below).
63
+ * The downloaded checkpoint directory has the following structure:
64
+ ```
65
+ -- checkpoints
66
+ |-- StableDiffusion
67
+ |-- models
68
+ |-- clip-vit-large-patch14
69
+ |-- control_v11p_sd15_lineart
70
+ |-- Annotators
71
+ |--sk_model.pth
72
+ |-- MangaNinjia
73
+ |-- denoising_unet.pth
74
+ |-- reference_unet.pth
75
+ |-- point_net.pth
76
+ |-- controlnet.pth
77
+ ```
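
As an optional shortcut, here is a minimal, hedged sketch of fetching the Hugging Face-hosted weights programmatically with `huggingface_hub.snapshot_download` (the library is already pinned in `requirements.txt`). The `local_dir` targets are assumptions that mirror the layout above; the Stable Diffusion v1.5 base from the ModelScope link still needs to be placed in `checkpoints/StableDiffusion` separately.

```python
# Hedged sketch: repo IDs come from the links above; local_dir values assume the layout shown above.
from huggingface_hub import hf_hub_download, snapshot_download

snapshot_download("Johanan0528/MangaNinjia", local_dir="checkpoints/MangaNinjia")
snapshot_download("openai/clip-vit-large-patch14", local_dir="checkpoints/models/clip-vit-large-patch14")
snapshot_download("lllyasviel/control_v11p_sd15_lineart", local_dir="checkpoints/models/control_v11p_sd15_lineart")
hf_hub_download("lllyasviel/Annotators", filename="sk_model.pth", local_dir="checkpoints/models/Annotators")
# The Stable Diffusion v1.5 base (see the ModelScope link above) goes into checkpoints/StableDiffusion.
```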
78
+
79
+
80
+ ## 🎮 Inference
81
+ ```bash
82
+ cd scripts
83
+ bash infer.sh
84
+ ```
85
+
86
+ You can find all results in `output/`. Enjoy!
87
+
88
+ #### 📍 Inference settings
89
+
90
+ The default settings are optimized for the best results, but the behavior of the code can be customized:
91
+ - `--denoise_steps`: Number of denoising steps for each inference pass. For the original (DDIM) version, 20-50 steps are recommended.
92
+ - `--is_lineart`: Omit this flag if the input is an RGB image from which the line art should first be extracted; pass it if the input is already line art and no extraction is needed.
93
+ - `--guidance_scale_ref`: Increasing this makes the model follow the guidance of the reference image more closely.
94
+ - `--guidance_scale_point`: Increasing this makes the model follow the point guidance more closely, yielding more customized colorization.
95
+ - `--point_ref_paths` and `--point_lineart_paths` (**optional**): Two 512x512 matrices that encode the matching points between the reference and the line art using consecutive integers: the coordinates of each matching pair carry the same value (1, 2, 3, ...) in both matrices, and all other positions are 0 (you can refer to the provided samples, or to the sketch below). For point guidance, however, we recommend using Gradio.
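
For illustration, a minimal sketch of building such a pair of point matrices with NumPy and saving them as `.npy` files, which is the format `infer.py` loads; the file names and point coordinates below are hypothetical:

```python
import numpy as np

# Hypothetical matching points as (row, col) pairs in the 512x512 reference and line art.
ref_points     = [(120, 200), (300, 310)]
lineart_points = [(118, 205), (295, 320)]

point_ref     = np.zeros((512, 512), dtype=np.uint8)
point_lineart = np.zeros((512, 512), dtype=np.uint8)
for label, ((r1, c1), (r2, c2)) in enumerate(zip(ref_points, lineart_points), start=1):
    point_ref[r1, c1] = label      # the same label (1, 2, 3, ...) marks a matching pair
    point_lineart[r2, c2] = label

np.save("point_ref.npy", point_ref)          # pass via --point_ref_paths
np.save("point_lineart.npy", point_lineart)  # pass via --point_lineart_paths
```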
96
+
97
+ ## 🌱 Gradio
98
+ First, modify `./configs/inference.yaml` to set the model weight paths. Then run the script:
99
+ ```bash
100
+ python run_gradio.py
101
+ ```
102
+ The Gradio demo UI looks like the screenshots below.
103
+ <table align="center">
104
+ <tr>
105
+ <td>
106
+ <img src="docs/gradio1.png" width="300" height="400">
107
+ </td>
108
+ <td>
109
+ <img src="docs/gradio2.png" width="300" height="400">
110
+ </td>
111
+ </tr>
112
+ </table>
113
+ A brief tutorial:
114
+
115
+ 1. Upload the reference image and target image.
116
+
117
+ Note that for the target image, there are two modes: you can upload an RGB image, and the model will automatically extract the line art; or you can directly upload the line art by checking the 'input is lineart' option.
118
+
119
+ Line art inputs are single-channel grayscale images holding floating-point values, with the background at 0 and the line art close to 1 (see the sketch after this tutorial). Additionally, we would like to hear from our users: if the line art you commonly use is binarized, please let us know, and we will fine-tune the model and release an updated version to better suit your needs. 😆
120
+
121
+ 2. Click 'Process Images' to resize the images to 512×512 resolution.
122
+ 3. (Optional) **Starting from the reference image**, **alternately** click on the reference and target images in sequence to define matching points. Use 'Undo' to revert the last action.
123
+ 4. Click 'Generate' to produce the result.
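
As referenced in step 1, here is a minimal sketch of preparing a line art input in the expected format (single-channel, background at 0, line art close to 1). The file names and the inversion heuristic are assumptions, not part of the official pipeline:

```python
import numpy as np
from PIL import Image

img = Image.open("my_lineart.png").convert("L")    # single-channel grayscale
arr = np.asarray(img, dtype=np.float32) / 255.0    # floating-point values in [0, 1]
if arr.mean() > 0.5:                               # assumed heuristic: dark lines on a white page
    arr = 1.0 - arr                                # invert so the background is ~0 and lines are ~1
Image.fromarray((arr * 255).astype(np.uint8)).save("my_lineart_prepared.png")
```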
124
+ ## 🌺 Acknowledgements
125
+ This project is developed on the codebase of [MagicAnimate](https://github.com/magic-research/magic-animate). We appreciate this great work!
126
+
127
+ ## 🎓 Citation
128
+
129
+ Please cite our paper:
130
+
131
+ ```bibtex
132
+ @article{liu2024manganinja,
133
+ author = {Zhiheng Liu and Ka Leong Cheng and Xi Chen and Jie Xiao and Hao Ouyang and Kai Zhu and Yu Liu and Yujun Shen
134
+ and Qifeng Chen and Ping Luo},
135
+ title = {MangaNinja: Line Art Colorization with Precise Reference Following},
136
+ journal = {CoRR},
137
+ volume = {abs/2501.08332},
138
+ year = {2024}
139
+ }
140
+ ```
configs/inference.yaml ADDED
@@ -0,0 +1,12 @@
1
+ model_path:
2
+ pretrained_model_name_or_path: ./checkpoints/StableDiffusion
3
+ clip_vision_encoder_path: ./checkpoints/models/clip-vit-large-patch14
4
+ controlnet_model_name: './checkpoints/models/control_v11p_sd15_lineart'
5
+ annotator_ckpts_path: ./checkpoints/models/Annotators
6
+ manga_control_model_path: ./checkpoints/MangaNinjia/controlnet.pth
7
+ manga_reference_model_path: ./checkpoints/MangaNinjia/reference_unet.pth
8
+ manga_main_model_path: ./checkpoints/MangaNinjia/denoising_unet.pth
9
+ point_net_path: ./checkpoints/MangaNinjia/point_net.pth
10
+ inference_config:
11
+ output_path: output
12
+ device: cuda
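
As a hedged usage sketch (OmegaConf is already imported by `run_gradio.py`), the paths above can be read like this; the keys simply mirror the file:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/inference.yaml")
print(cfg.model_path.pretrained_model_name_or_path)  # ./checkpoints/StableDiffusion
print(cfg.model_path.point_net_path)                 # ./checkpoints/MangaNinjia/point_net.pth
print(cfg.inference_config.device)                   # cuda
```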
docs/badge-website.svg ADDED
docs/gradio.png ADDED
docs/gradio1.png ADDED
docs/gradio2.png ADDED
docs/teaser.gif ADDED

Git LFS Details

  • SHA256: 2438aa5b2b949a0d2989258ed46415a60d7a30f4b1b472885a8dd3487da17058
  • Pointer size: 133 Bytes
  • Size of remote file: 12.5 MB
docs/teaser.png ADDED
environment.yaml ADDED
@@ -0,0 +1,37 @@
1
+ name: MangaNinjia
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - pip=24.2
6
+ - python=3.10.14
7
+ - pip:
8
+ - accelerate==0.31.0
9
+ - diffusers==0.27.2
10
+ - gradio==3.39.0
11
+ - gradio-client==1.3.0
12
+ - h5py==3.11.0
13
+ - huggingface-hub==0.24.6
14
+ - imageio==2.35.1
15
+ - imageio-ffmpeg==0.5.1
16
+ - importlib-metadata==8.4.0
17
+ - importlib-resources==6.4.5
18
+ - ipdb==0.13.13
19
+ - ipython==8.26.0
20
+ - ipywidgets==8.1.5
21
+ - kornia==0.7.3
22
+ - kornia-rs==0.1.5
23
+ - omegaconf==2.3.0
24
+ - opencv-python==4.10.0.84
25
+ - pandas==2.2.2
26
+ - pillow==10.4.0
27
+ - scikit-image==0.24.0
28
+ - scikit-learn==1.5.2
29
+ - scipy==1.14.1
30
+ - torch==2.3.0
31
+ - torchaudio==2.3.0
32
+ - torchmetrics==1.4.1
33
+ - torchvision==0.18.0
34
+ - tqdm==4.66.5
35
+ - transformers==4.44.1
36
+ - einops==0.8.0
37
+ - basicsr==1.3.5
infer.py ADDED
@@ -0,0 +1,277 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import random
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+ from tqdm.auto import tqdm
9
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
10
+ import torch.nn as nn
11
+ from inference.manganinjia_pipeline import MangaNinjiaPipeline
12
+ from diffusers import (
13
+ ControlNetModel,
14
+ DiffusionPipeline,
15
+ DDIMScheduler,
16
+ AutoencoderKL,
17
+ )
18
+ from src.models.mutual_self_attention_multi_scale import ReferenceAttentionControl
19
+ from src.models.unet_2d_condition import UNet2DConditionModel
20
+ from src.models.refunet_2d_condition import RefUNet2DConditionModel
21
+ from src.point_network import PointNet
22
+ from src.annotator.lineart import BatchLineartDetector
23
+
24
+ if "__main__" == __name__:
25
+ logging.basicConfig(level=logging.INFO)
26
+
27
+ # -------------------- Arguments --------------------
28
+ parser = argparse.ArgumentParser(
29
+ description="Run single-image MangaNinjia"
30
+ )
31
+ parser.add_argument(
32
+ "--output_dir", type=str, required=True, help="Output directory."
33
+ )
34
+
35
+ # inference setting
36
+ parser.add_argument(
37
+ "--denoise_steps",
38
+ type=int,
39
+ default=50, # quantitative evaluation uses 50 steps
40
+ help="Diffusion denoising steps, more steps results in higher accuracy but slower inference speed.",
41
+ )
42
+
43
+ # resolution setting
44
+ parser.add_argument("--seed", type=int, default=None, help="Random seed.")
45
+
46
+ parser.add_argument(
47
+ "--pretrained_model_name_or_path",
48
+ type=str,
49
+ default=None,
50
+ required=True,
51
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
52
+ )
53
+ parser.add_argument(
54
+ "--image_encoder_path",
55
+ type=str,
56
+ default=None,
57
+ required=True,
58
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
59
+ )
60
+ parser.add_argument(
61
+ "--controlnet_model_name_or_path", type=str, required=True, help="Path to original controlnet."
62
+ )
63
+ parser.add_argument(
64
+ "--annotator_ckpts_path", type=str, required=True, help="Path to depth inpainting model."
65
+ )
66
+ parser.add_argument(
67
+ "--manga_reference_unet_path", type=str, required=True, help="Path to depth inpainting model."
68
+ )
69
+ parser.add_argument(
70
+ "--manga_main_model_path", type=str, required=True, help="Path to depth inpainting model."
71
+ )
72
+ parser.add_argument(
73
+ "--manga_controlnet_model_path", type=str, required=True, help="Path to depth inpainting model."
74
+ )
75
+ parser.add_argument(
76
+ "--point_net_path", type=str, required=True, help="Path to depth inpainting model."
77
+ )
78
+ parser.add_argument(
79
+ "--input_reference_paths",
80
+ nargs='+',
81
+ default=None,
82
+ help="input_image_paths",
83
+ )
84
+ parser.add_argument(
85
+ "--input_lineart_paths",
86
+ nargs='+',
87
+ default=None,
88
+ help="lineart_paths",
89
+ )
90
+ parser.add_argument(
91
+ "--point_ref_paths",
92
+ type=str,
93
+ default=None,
94
+ nargs="+",
95
+ )
96
+ parser.add_argument(
97
+ "--point_lineart_paths",
98
+ type=str,
99
+ default=None,
100
+ nargs="+",
101
+ )
102
+ parser.add_argument(
103
+ "--is_lineart",
104
+ action="store_true",
105
+ default=False
106
+ )
107
+ parser.add_argument(
108
+ "--guidance_scale_ref",
109
+ type=float,
110
+ default=1e-4,
111
+ help="guidance scale for reference image",
112
+ )
113
+ parser.add_argument(
114
+ "--guidance_scale_point",
115
+ type=float,
116
+ default=1e-4,
117
+ help="guidance scale for points",
118
+ )
119
+ args = parser.parse_args()
120
+ output_dir = args.output_dir
121
+ denoise_steps = args.denoise_steps
122
+ seed = args.seed
123
+ is_lineart = args.is_lineart
124
+ os.makedirs(output_dir, exist_ok=True)
125
+ logging.info(f"output dir = {output_dir}")
126
+ if args.input_reference_paths is not None:
127
+ assert len(args.input_reference_paths) == len(args.input_lineart_paths)
128
+ input_reference_paths = args.input_reference_paths
129
+ input_lineart_paths = args.input_lineart_paths
130
+ point_ref_paths = args.point_ref_paths
131
+ point_lineart_paths = args.point_lineart_paths
132
+ if point_ref_paths is not None:  # avoid a NameError below when --point_ref_paths is omitted
133
+ assert len(point_ref_paths) == len(point_lineart_paths)
134
+ print(f"arguments: {args}")
135
+ if seed is None:
136
+ import time
137
+
138
+ seed = int(time.time())
139
+ generator = torch.cuda.manual_seed(seed)
140
+ # -------------------- Device --------------------
141
+ if torch.cuda.is_available():
142
+ device = torch.device("cuda")
143
+ else:
144
+ device = torch.device("cpu")
145
+ logging.warning("CUDA is not available. Running on CPU will be slow.")
146
+ logging.info(f"device = {device}")
147
+
148
+ # -------------------- Model --------------------
149
+ preprocessor = BatchLineartDetector(args.annotator_ckpts_path)
150
+ preprocessor.to(device,dtype=torch.float32)
151
+ in_channels_reference_unet = 4
152
+ in_channels_denoising_unet = 4
153
+ in_channels_controlnet = 4
154
+ noise_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path,subfolder='scheduler')
155
+ vae = AutoencoderKL.from_pretrained(
156
+ args.pretrained_model_name_or_path,
157
+ subfolder='vae'
158
+ )
159
+
160
+ denoising_unet = UNet2DConditionModel.from_pretrained(
161
+ args.pretrained_model_name_or_path,subfolder="unet",
162
+ in_channels=in_channels_denoising_unet,
163
+ low_cpu_mem_usage=False,
164
+ ignore_mismatched_sizes=True
165
+ )
166
+
167
+ reference_unet = RefUNet2DConditionModel.from_pretrained(
168
+ args.pretrained_model_name_or_path,subfolder="unet",
169
+ in_channels=in_channels_reference_unet,
170
+ low_cpu_mem_usage=False,
171
+ ignore_mismatched_sizes=True
172
+ )
173
+ refnet_tokenizer = CLIPTokenizer.from_pretrained(args.image_encoder_path)
174
+ refnet_text_encoder = CLIPTextModel.from_pretrained(args.image_encoder_path)
175
+ refnet_image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
176
+
177
+ controlnet = ControlNetModel.from_pretrained(
178
+ args.controlnet_model_name_or_path,
179
+ in_channels=in_channels_controlnet,
180
+ low_cpu_mem_usage=False,
181
+ ignore_mismatched_sizes=True
182
+ )
183
+ controlnet_tokenizer = CLIPTokenizer.from_pretrained(args.image_encoder_path)
184
+ controlnet_text_encoder = CLIPTextModel.from_pretrained(args.image_encoder_path)
185
+ controlnet_image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
186
+
187
+
188
+ point_net=PointNet()
189
+
190
+
191
+
192
+ controlnet.load_state_dict(
193
+ torch.load(args.manga_controlnet_model_path, map_location="cpu"),
194
+ strict=False,
195
+ )
196
+ point_net.load_state_dict(
197
+ torch.load(args.point_net_path, map_location="cpu"),
198
+ strict=False,
199
+ )
200
+ reference_unet.load_state_dict(
201
+ torch.load(args.manga_reference_unet_path, map_location="cpu"),
202
+ strict=False,
203
+ )
204
+ denoising_unet.load_state_dict(
205
+ torch.load(args.manga_main_model_path, map_location="cpu"),
206
+ strict=False,
207
+ )
208
+ pipe = MangaNinjiaPipeline(
209
+ reference_unet=reference_unet,
210
+ controlnet=controlnet,
211
+ denoising_unet=denoising_unet,
212
+ vae=vae,
213
+ refnet_tokenizer=refnet_tokenizer,
214
+ refnet_text_encoder=refnet_text_encoder,
215
+ refnet_image_encoder=refnet_image_encoder,
216
+ controlnet_tokenizer=controlnet_tokenizer,
217
+ controlnet_text_encoder=controlnet_text_encoder,
218
+ controlnet_image_encoder=controlnet_image_encoder,
219
+ scheduler=noise_scheduler,
220
+ point_net=point_net
221
+ )
222
+ pipe = pipe.to(torch.device(device))
223
+
224
+ # -------------------- Inference and saving --------------------
225
+ with torch.no_grad():
226
+ for i in range(len(input_reference_paths)):
227
+ input_reference_path = input_reference_paths[i]
228
+ input_lineart_path = input_lineart_paths[i]
229
+
230
+ # save path
231
+ rgb_name_base = os.path.splitext(os.path.basename(input_reference_path))[0]
232
+ pred_name_base = rgb_name_base + "_colorized"
233
+ lineart_name_base = rgb_name_base + "_lineart"
234
+ colored_save_path = os.path.join(
235
+ output_dir, f"{pred_name_base}.png"
236
+ )
237
+ lineart_save_path = os.path.join(
238
+ output_dir, f"{lineart_name_base}.png"
239
+ )
240
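+ # Load the optional 512x512 point-matching matrices (.npy files carrying identical integer labels at matching coordinates); fall back to all-zero matrices, i.e. no point guidance, when none are provided.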
+ if point_ref_paths is not None:
241
+ point_ref_path = point_ref_paths[i]
242
+ point_lineart_path = point_lineart_paths[i]
243
+ point_ref = torch.from_numpy(np.load(point_ref_path)).unsqueeze(0).unsqueeze(0)
244
+ point_main = torch.from_numpy(np.load(point_lineart_path)).unsqueeze(0).unsqueeze(0)
245
+ else:
246
+ matrix1 = np.zeros((512, 512), dtype=np.uint8)
247
+ matrix2 = np.zeros((512, 512), dtype=np.uint8)
248
+ point_ref = torch.from_numpy(matrix1).unsqueeze(0).unsqueeze(0)
249
+ point_main = torch.from_numpy(matrix2).unsqueeze(0).unsqueeze(0)
250
+ ref_image = Image.open(input_reference_path)
251
+ ref_image = ref_image.resize((512, 512))
252
+ target_image = Image.open(input_lineart_path)
253
+ target_image = target_image.resize((512, 512))
254
+ pipe_out = pipe(
255
+ is_lineart,
256
+ ref_image,
257
+ target_image,
258
+ target_image,
259
+ denosing_steps=denoise_steps,
260
+ processing_res=512,
261
+ match_input_res=True,
262
+ batch_size=1,
263
+ show_progress_bar=True,
264
+ guidance_scale_ref=args.guidance_scale_ref,
265
+ guidance_scale_point=args.guidance_scale_point,
266
+ preprocessor=preprocessor,
267
+ generator=generator,
268
+ point_ref=point_ref,
269
+ point_main=point_main,
270
+ )
271
+
272
+ if os.path.exists(colored_save_path):
273
+ logging.warning(f"Existing file: '{colored_save_path}' will be overwritten")
274
+ image = pipe_out.img_pil
275
+ lineart = pipe_out.to_save_dict['edge2_black']
276
+ image.save(colored_save_path)
277
+ lineart.save(lineart_save_path)
inference/manganinjia_pipeline.py ADDED
@@ -0,0 +1,486 @@
1
+
2
+ from typing import Any, Dict, Union
3
+ import torchvision.transforms as transforms
4
+ import torch
5
+ from torch.utils.data import DataLoader, TensorDataset
6
+ import numpy as np
7
+ from tqdm.auto import tqdm
8
+ from PIL import Image
9
+ from diffusers import (
10
+ DiffusionPipeline,
11
+ ControlNetModel,
12
+ DDIMScheduler,
13
+ AutoencoderKL,
14
+ )
15
+ from diffusers.utils import BaseOutput
16
+ from transformers import CLIPTextModel, CLIPTokenizer
17
+ from transformers import CLIPImageProcessor
18
+ from transformers import CLIPVisionModelWithProjection
19
+
20
+ from utils.image_util import resize_max_res,chw2hwc
21
+ from src.point_network import PointNet
22
+ from src.models.mutual_self_attention_multi_scale import ReferenceAttentionControl
23
+ from src.models.unet_2d_condition import UNet2DConditionModel
24
+ from src.models.refunet_2d_condition import RefUNet2DConditionModel
25
+
26
+
27
+ class MangaNinjiaPipelineOutput(BaseOutput):
28
+ img_np: np.ndarray
29
+ img_pil: Image.Image
30
+ to_save_dict: dict
31
+
32
+
33
+ class MangaNinjiaPipeline(DiffusionPipeline):
34
+ rgb_latent_scale_factor = 0.18215
35
+
36
+ def __init__(self,
37
+ reference_unet: RefUNet2DConditionModel,
38
+ controlnet: ControlNetModel,
39
+ denoising_unet: UNet2DConditionModel,
40
+ vae: AutoencoderKL,
41
+ refnet_tokenizer: CLIPTokenizer,
42
+ refnet_text_encoder: CLIPTextModel,
43
+ refnet_image_encoder: CLIPVisionModelWithProjection,
44
+ controlnet_tokenizer: CLIPTokenizer,
45
+ controlnet_text_encoder: CLIPTextModel,
46
+ controlnet_image_encoder: CLIPVisionModelWithProjection,
47
+ scheduler: DDIMScheduler,
48
+ point_net: PointNet
49
+ ):
50
+ super().__init__()
51
+
52
+ self.register_modules(
53
+ reference_unet=reference_unet,
54
+ controlnet=controlnet,
55
+ denoising_unet=denoising_unet,
56
+ vae=vae,
57
+ refnet_tokenizer=refnet_tokenizer,
58
+ refnet_text_encoder=refnet_text_encoder,
59
+ refnet_image_encoder=refnet_image_encoder,
60
+ controlnet_tokenizer=controlnet_tokenizer,
61
+ controlnet_text_encoder=controlnet_text_encoder,
62
+ controlnet_image_encoder=controlnet_image_encoder,
63
+ point_net=point_net,
64
+ scheduler=scheduler,
65
+ )
66
+ self.empty_text_embed = None
67
+ self.clip_image_processor = CLIPImageProcessor()
68
+
69
+ @torch.no_grad()
70
+ def __call__(
71
+ self,
72
+ is_lineart: bool,
73
+ ref1: Image.Image,
74
+ raw2: Image.Image,
75
+ edit2: Image.Image,
76
+ denosing_steps: int = 20,
77
+ processing_res: int = 512,
78
+ match_input_res: bool = True,
79
+ batch_size: int = 0,
80
+ show_progress_bar: bool = True,
81
+ guidance_scale_ref: float = 7,
82
+ guidance_scale_point: float = 12,
83
+ preprocessor=None,
84
+ generator=None,
85
+ point_ref=None,
86
+ point_main=None,
87
+ ) -> MangaNinjiaPipelineOutput:
88
+
89
+ device = self.device
90
+
91
+ input_size = raw2.size
92
+ point_ref=point_ref.float().to(device)
93
+ point_main=point_main.float().to(device)
94
+ def img2embeds(img, image_enc):
95
+ clip_image = self.clip_image_processor.preprocess(
96
+ img, return_tensors="pt"
97
+ ).pixel_values
98
+ clip_image_embeds = image_enc(
99
+ clip_image.to(device, dtype=image_enc.dtype)
100
+ ).image_embeds
101
+ encoder_hidden_states = clip_image_embeds.unsqueeze(1)
102
+ return encoder_hidden_states
103
+ if self.reference_unet:
104
+ refnet_encoder_hidden_states = img2embeds(ref1, self.refnet_image_encoder)
105
+ else:
106
+ refnet_encoder_hidden_states = None
107
+ if self.controlnet:
108
+ controlnet_encoder_hidden_states = img2embeds(ref1, self.controlnet_image_encoder)
109
+ else:
110
+ controlnet_encoder_hidden_states = None
111
+
112
+ prompt = ""
113
+ def prompt2embeds(prompt, tokenizer, text_encoder):
114
+ text_inputs = tokenizer(
115
+ prompt,
116
+ padding="do_not_pad",
117
+ max_length=tokenizer.model_max_length,
118
+ truncation=True,
119
+ return_tensors="pt",
120
+ )
121
+ text_input_ids = text_inputs.input_ids.to(device) #[1,2]
122
+ empty_text_embed = text_encoder(text_input_ids)[0].to(self.dtype)
123
+ uncond_encoder_hidden_states = empty_text_embed.repeat((1, 1, 1))[:,0,:].unsqueeze(0)
124
+ return uncond_encoder_hidden_states
125
+ if self.reference_unet:
126
+ refnet_uncond_encoder_hidden_states = prompt2embeds(prompt, self.refnet_tokenizer, self.refnet_text_encoder)
127
+ else:
128
+ refnet_uncond_encoder_hidden_states = None
129
+ if self.controlnet:
130
+ controlnet_uncond_encoder_hidden_states = prompt2embeds(prompt, self.controlnet_tokenizer, self.controlnet_text_encoder)
131
+ else:
132
+ controlnet_uncond_encoder_hidden_states = None
133
+
134
+ do_classifier_free_guidance = guidance_scale_ref > 1.0
135
+
136
+ # adjust the input resolution.
137
+ if not match_input_res:
138
+ assert (
139
+ processing_res is not None
140
+ )," Value Error: `resize_output_back` is only valid with "
141
+
142
+ assert processing_res >= 0
143
+ assert denosing_steps >= 1
144
+
145
+ # --------------- Image Processing ------------------------
146
+ # Resize image
147
+ if processing_res > 0:
148
+ def resize_img(img):
149
+ img = resize_max_res(img, max_edge_resolution=processing_res)
150
+ return img
151
+ ref1 = resize_img(ref1)
152
+ raw2 = resize_img(raw2)
153
+ edit2 = resize_img(edit2)
154
+
155
+ # Normalize image
156
+ def normalize_img(img):
157
+ img = img.convert("RGB")
158
+ img = np.array(img)
159
+
160
+ # Normalize RGB Values.
161
+ rgb = np.transpose(img,(2,0,1))
162
+ rgb_norm = rgb / 255.0 * 2.0 - 1.0
163
+ rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
164
+ rgb_norm = rgb_norm.to(device)
165
+ img = rgb_norm
166
+ assert img.min() >= -1.0 and img.max() <= 1.0
167
+ return img
168
+ raw2_real = raw2.convert('L')
169
+ ref1 = normalize_img(ref1)
170
+ raw2 = normalize_img(raw2)
171
+ edit2 = normalize_img(edit2)
172
+ single_rgb_dataset = TensorDataset(ref1[None], raw2[None], edit2[None])
173
+
174
+
175
+ # find the batch size
176
+ if batch_size>0:
177
+ _bs = batch_size
178
+ else:
179
+ _bs = 1
180
+ point_ref=self.point_net(point_ref)
181
+ point_main=self.point_net(point_main)
182
+ single_rgb_loader = DataLoader(single_rgb_dataset,batch_size=_bs,shuffle=False)
183
+
184
+ # classifier guidance
185
+ if do_classifier_free_guidance:
186
+ if self.reference_unet:
187
+ refnet_encoder_hidden_states = torch.cat(
188
+ [refnet_uncond_encoder_hidden_states, refnet_encoder_hidden_states,refnet_encoder_hidden_states], dim=0
189
+ )
190
+ else:
191
+ refnet_encoder_hidden_states = None
192
+
193
+ if self.controlnet:
194
+ controlnet_encoder_hidden_states = torch.cat(
195
+ [controlnet_uncond_encoder_hidden_states, controlnet_encoder_hidden_states,controlnet_encoder_hidden_states], dim=0
196
+ )
197
+ else:
198
+ controlnet_encoder_hidden_states = None
199
+
200
+ if self.reference_unet:
201
+ reference_control_writer = ReferenceAttentionControl(
202
+ self.reference_unet,
203
+ do_classifier_free_guidance=do_classifier_free_guidance,
204
+ mode="write",
205
+ batch_size=batch_size,
206
+ fusion_blocks="full",
207
+ )
208
+ reference_control_reader = ReferenceAttentionControl(
209
+ self.denoising_unet,
210
+ do_classifier_free_guidance=do_classifier_free_guidance,
211
+ mode="read",
212
+ batch_size=batch_size,
213
+ fusion_blocks="full",
214
+ )
215
+ else:
216
+ reference_control_writer = None
217
+ reference_control_reader = None
218
+
219
+ if show_progress_bar:
220
+ iterable_bar = tqdm(
221
+ single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
222
+ )
223
+ else:
224
+ iterable_bar = single_rgb_loader
225
+
226
+ assert len(iterable_bar) == 1
227
+ for batch in iterable_bar:
228
+ (ref1, raw2, edit2) = batch # here the image is still around 0-1
229
+ if is_lineart:
230
+ raw2 = raw2_real
231
+ img_pred, to_save_dict = self.single_infer(
232
+ is_lineart=is_lineart,
233
+ ref1=ref1,
234
+ raw2=raw2,
235
+ edit2=edit2,
236
+ num_inference_steps=denosing_steps,
237
+ show_pbar=show_progress_bar,
238
+ guidance_scale_ref=guidance_scale_ref,
239
+ guidance_scale_point=guidance_scale_point,
240
+ refnet_encoder_hidden_states=refnet_encoder_hidden_states,
241
+ controlnet_encoder_hidden_states=controlnet_encoder_hidden_states,
242
+ reference_control_writer=reference_control_writer,
243
+ reference_control_reader=reference_control_reader,
244
+ preprocessor=preprocessor,
245
+ generator=generator,
246
+ point_ref=point_ref,
247
+ point_main=point_main
248
+ )
249
+ for k, v in to_save_dict.items():
250
+ if k =='edge2_black':
251
+ to_save_dict[k] = Image.fromarray(
252
+ ((to_save_dict['edge2_black'][:,0].squeeze().detach().cpu().numpy() + 1.) / 2 * 255).astype(np.uint8)
253
+ )
254
+ else:
255
+ try:
256
+ to_save_dict[k] = Image.fromarray(
257
+ chw2hwc(((v.squeeze().detach().cpu().numpy() + 1.) / 2 * 255).astype(np.uint8))
258
+ )
259
+ except:
260
+ import ipdb;ipdb.set_trace()
261
+
262
+ torch.cuda.empty_cache() # clear vram cache for ensembling
263
+
264
+ # ----------------- Post processing -----------------
265
+ # Convert to numpy
266
+ img_pred = img_pred.squeeze().cpu().numpy().astype(np.float32)
267
+ img_pred_np = (((img_pred + 1.) / 2.) * 255).astype(np.uint8)
268
+ img_pred_np = chw2hwc(img_pred_np)
269
+ img_pred_pil = Image.fromarray(img_pred_np)
270
+
271
+ # Resize back to original resolution
272
+ if match_input_res:
273
+ img_pred_pil = img_pred_pil.resize(input_size)
274
+ img_pred_np = np.asarray(img_pred_pil)
275
+
276
+ return MangaNinjiaPipelineOutput(
277
+ img_np=img_pred_np,
278
+ img_pil=img_pred_pil,
279
+ to_save_dict=to_save_dict
280
+ )
281
+
282
+
283
+ def __encode_empty_text(self):
284
+ """
285
+ Encode text embedding for empty prompt
286
+ """
287
+ prompt = ""
288
+ text_inputs = self.tokenizer(
289
+ prompt,
290
+ padding="do_not_pad",
291
+ max_length=self.tokenizer.model_max_length,
292
+ truncation=True,
293
+ return_tensors="pt",
294
+ )
295
+ text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) #[1,2]
296
+ # print(text_input_ids.shape)
297
+ self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype) #[1,2,1024]
298
+
299
+ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
300
+ # get the original timestep using init_timestep
301
+ if denoising_start is None:
302
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
303
+ t_start = max(num_inference_steps - init_timestep, 0)
304
+ else:
305
+ t_start = 0
306
+
307
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
308
+
309
+ # Strength is irrelevant if we directly request a timestep to start at;
310
+ # that is, strength is determined by the denoising_start instead.
311
+ if denoising_start is not None:
312
+ discrete_timestep_cutoff = int(
313
+ round(
314
+ self.scheduler.config.num_train_timesteps
315
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
316
+ )
317
+ )
318
+ timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
319
+ return torch.tensor(timesteps), len(timesteps)
320
+
321
+ return timesteps, num_inference_steps - t_start
322
+
323
+ @torch.no_grad()
324
+ def single_infer(
325
+ self,
326
+ is_lineart: bool,
327
+ ref1: torch.Tensor,
328
+ raw2: torch.Tensor,
329
+ edit2: torch.Tensor,
330
+ num_inference_steps: int,
331
+ show_pbar: bool,
332
+ guidance_scale_ref: float,
333
+ guidance_scale_point: float,
334
+ refnet_encoder_hidden_states: torch.Tensor,
335
+ controlnet_encoder_hidden_states: torch.Tensor,
336
+ reference_control_writer: ReferenceAttentionControl,
337
+ reference_control_reader: ReferenceAttentionControl,
338
+ preprocessor,
339
+ generator,
340
+ point_ref,
341
+ point_main
342
+ ):
343
+ do_classifier_free_guidance = guidance_scale_ref > 1.0
344
+ device = ref1.device
345
+ to_save_dict = {
346
+ 'ref1': ref1,
347
+ }
348
+
349
+ # Set timesteps (inherited from the diffusion pipeline)
350
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
351
+ timesteps = self.scheduler.timesteps # [T]
352
+
353
+ # encode image
354
+ ref1_latents = self.encode_RGB(ref1, generator=generator) # 1/8 Resolution with a channel nums of 4.
355
+ edge2_src = raw2
356
+
357
+ timesteps_add,_=self.get_timesteps(num_inference_steps, 1.0, device, denoising_start=None)
358
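+ # Extract line art with the annotator unless the input is already line art; in that case, scale the grayscale to [0, 1] and zero out faint values (<= 0.24) to keep the background clean.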
+ if is_lineart is not True:
359
+ edge2 = preprocessor(edge2_src)
360
+ else:
361
+ gray_image_np = np.array(edge2_src)
362
+ gray_image_np = gray_image_np / 255.0
363
+ edge2 = torch.from_numpy(gray_image_np.astype(np.float32)).unsqueeze(0).unsqueeze(0).cuda()
364
+ edge2[edge2<=0.24]=0
365
+ edge2_black = edge2.repeat(1, 3, 1, 1) * 2 - 1.
366
+ to_save_dict['edge2_black']=edge2_black
367
+
368
+ edge2 = edge2.repeat(1, 3, 1, 1) * 2 - 1.
369
+ to_save_dict['edge2'] = (1-((edge2+1.)/2))*2-1
370
+
371
+ noisy_edit2_latents = torch.randn(
372
+ ref1_latents.shape, device=device, dtype=self.dtype
373
+ ) # [B, 4, H/8, W/8]
374
+
375
+
376
+ # Denoising loop
377
+ if show_pbar:
378
+ iterable = tqdm(
379
+ enumerate(timesteps),
380
+ total=len(timesteps),
381
+ leave=False,
382
+ desc=" " * 4 + "Diffusion denoising",
383
+ )
384
+ else:
385
+ iterable = enumerate(timesteps)
386
+
387
+ for i, t in iterable:
388
+
389
+ refnet_input = ref1_latents
390
+ controlnet_inputs = (noisy_edit2_latents, edge2)
391
+ unet_input = torch.cat([noisy_edit2_latents], dim=1)
392
+
393
+ if i == 0:
394
+ if self.reference_unet:
395
+ self.reference_unet(
396
+ refnet_input.repeat(
397
+ (3 if do_classifier_free_guidance else 1), 1, 1, 1
398
+ ),
399
+ torch.zeros_like(t),
400
+
401
+ encoder_hidden_states=refnet_encoder_hidden_states,
402
+ return_dict=False,
403
+ )
404
+ reference_control_reader.update(reference_control_writer,point_embedding_ref=point_ref,point_embedding_main=point_main) # size is wrong
405
+
406
+ if self.controlnet:
407
+ noisy_latents, controlnet_cond = controlnet_inputs
408
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
409
+ noisy_latents.repeat(
410
+ (3 if do_classifier_free_guidance else 1), 1, 1, 1
411
+ ),
412
+ t,
413
+ encoder_hidden_states=controlnet_encoder_hidden_states,
414
+ controlnet_cond=controlnet_cond.repeat(
415
+ (3 if do_classifier_free_guidance else 1), 1, 1, 1
416
+ ),
417
+ return_dict=False,
418
+ )
419
+ else:
420
+ down_block_res_samples, mid_block_res_sample = None, None
421
+
422
+ # predict the noise residual
423
+ noise_pred = self.denoising_unet(
424
+ unet_input.repeat(
425
+ (3 if do_classifier_free_guidance else 1), 1, 1, 1
426
+ ).to(dtype=self.denoising_unet.dtype),
427
+ t,
428
+ encoder_hidden_states=refnet_encoder_hidden_states,
429
+ down_block_additional_residuals=down_block_res_samples,
430
+ mid_block_additional_residual=mid_block_res_sample,
431
+ ).sample # [B, 4, h, w]
432
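+ # Three-way classifier-free guidance: split into unconditional / reference-guided / point-guided predictions, apply the two guidance scales, then average the resulting estimates.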
+ noise_pred_uncond, noise_pred_ref, noise_pred_point = noise_pred.chunk(3)
433
+ noise_pred_1 = noise_pred_uncond + guidance_scale_ref * (
434
+ noise_pred_ref - noise_pred_uncond
435
+ )
436
+ noise_pred_2 = noise_pred_ref + guidance_scale_point * (
437
+ noise_pred_point - noise_pred_ref
438
+ )
439
+ noise_pred=(noise_pred_1+noise_pred_2)/2
440
+ noisy_edit2_latents = self.scheduler.step(noise_pred, t, noisy_edit2_latents).prev_sample
441
+
442
+ reference_control_reader.clear()
443
+ reference_control_writer.clear()
444
+ torch.cuda.empty_cache()
445
+
446
+ # clip prediction
447
+ edit2 = self.decode_RGB(noisy_edit2_latents)
448
+ edit2 = torch.clip(edit2, -1.0, 1.0)
449
+
450
+ return edit2, to_save_dict
451
+
452
+
453
+ def encode_RGB(self, rgb_in: torch.Tensor, generator) -> torch.Tensor:
454
+ """
455
+ Encode RGB image into latent.
456
+
457
+ Args:
458
+ rgb_in (`torch.Tensor`):
459
+ Input RGB image to be encoded.
460
+
461
+ Returns:
462
+ `torch.Tensor`: Image latent.
463
+ """
464
+
465
+ # generator = None
466
+ rgb_latent = self.vae.encode(rgb_in).latent_dist.sample(generator)
467
+ rgb_latent = rgb_latent * self.rgb_latent_scale_factor
468
+ return rgb_latent
469
+
470
+ def decode_RGB(self, rgb_latent: torch.Tensor) -> torch.Tensor:
471
+ """
472
+ Decode RGB latent into an RGB image.
473
+
474
+ Args:
475
+ rgb_latent (`torch.Tensor`):
476
+ RGB latent to be decoded.
477
+
478
+ Returns:
479
+ `torch.Tensor`: Decoded RGB image.
480
+ """
481
+
482
+ rgb_latent = rgb_latent / self.rgb_latent_scale_factor
483
+ rgb_out = self.vae.decode(rgb_latent, return_dict=False)[0]
484
+ return rgb_out
485
+
486
+
output/hz0_colorized.png ADDED
output/hz0_lineart.png ADDED
output/hz1_colorized.png ADDED
requirements.txt ADDED
@@ -0,0 +1,30 @@
1
+ accelerate==0.31.0
2
+ diffusers==0.27.2
3
+ gradio==3.39.0
4
+ gradio-client==1.3.0
5
+ h5py==3.11.0
6
+ huggingface-hub==0.24.6
7
+ imageio==2.35.1
8
+ imageio-ffmpeg==0.5.1
9
+ importlib-metadata==8.4.0
10
+ importlib-resources==6.4.5
11
+ ipdb==0.13.13
12
+ ipython==8.26.0
13
+ ipywidgets==8.1.5
14
+ kornia==0.7.3
15
+ kornia-rs==0.1.5
16
+ omegaconf==2.3.0
17
+ opencv-python==4.10.0.84
18
+ pandas==2.2.2
19
+ pillow==10.4.0
20
+ scikit-image==0.24.0
21
+ scikit-learn==1.5.2
22
+ scipy==1.14.1
23
+ torch==2.3.0
24
+ torchaudio==2.3.0
25
+ torchmetrics==1.4.1
26
+ torchvision==0.18.0
27
+ tqdm==4.66.5
28
+ transformers==4.44.1
29
+ einops==0.8.0
30
+ basicsr==1.3.5
run_gradio.py ADDED
@@ -0,0 +1,381 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image, ImageDraw
4
+ import cv2
5
+ import gradio as gr
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from omegaconf import OmegaConf
9
+ import numpy as np
10
+ import os
11
+ import re
12
+ from PIL import Image, ImageDraw
13
+ import cv2
14
+ #
15
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
16
+ import torch.nn as nn
17
+ from inference.manganinjia_pipeline import MangaNinjiaPipeline
18
+ from diffusers import (
19
+ ControlNetModel,
20
+ DiffusionPipeline,
21
+ DDIMScheduler,
22
+ AutoencoderKL,
23
+ )
24
+ from src.models.mutual_self_attention_multi_scale import ReferenceAttentionControl
25
+ from src.models.unet_2d_condition import UNet2DConditionModel
26
+ from src.models.refunet_2d_condition import RefUNet2DConditionModel
27
+ from src.point_network import PointNet
28
+ from src.annotator.lineart import BatchLineartDetector
29
+ val_configs = OmegaConf.load('./configs/inference.yaml')
30
+ # === load the checkpoint ===
31
+ pretrained_model_name_or_path = val_configs.model_path.pretrained_model_name_or_path
32
+ refnet_clip_vision_encoder_path = val_configs.model_path.clip_vision_encoder_path
33
+ controlnet_clip_vision_encoder_path = val_configs.model_path.clip_vision_encoder_path
34
+ controlnet_model_name_or_path = val_configs.model_path.controlnet_model_name
35
+ annotator_ckpts_path = val_configs.model_path.annotator_ckpts_path
36
+
37
+ output_root = val_configs.inference_config.output_path
38
+ device = val_configs.inference_config.device
39
+ preprocessor = BatchLineartDetector(annotator_ckpts_path)
40
+ in_channels_reference_unet = 4
41
+ in_channels_denoising_unet = 4
42
+ in_channels_controlnet = 4
43
+ noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path,subfolder='scheduler')
44
+ vae = AutoencoderKL.from_pretrained(
45
+ pretrained_model_name_or_path,
46
+ subfolder='vae'
47
+ )
48
+
49
+ denoising_unet = UNet2DConditionModel.from_pretrained(
50
+ pretrained_model_name_or_path,subfolder="unet",
51
+ in_channels=in_channels_denoising_unet,
52
+ low_cpu_mem_usage=False,
53
+ ignore_mismatched_sizes=True
54
+ )
55
+
56
+ reference_unet = RefUNet2DConditionModel.from_pretrained(
57
+ pretrained_model_name_or_path,subfolder="unet",
58
+ in_channels=in_channels_reference_unet,
59
+ low_cpu_mem_usage=False,
60
+ ignore_mismatched_sizes=True
61
+ )
62
+ refnet_tokenizer = CLIPTokenizer.from_pretrained(refnet_clip_vision_encoder_path)
63
+ refnet_text_encoder = CLIPTextModel.from_pretrained(refnet_clip_vision_encoder_path)
64
+ refnet_image_enc = CLIPVisionModelWithProjection.from_pretrained(refnet_clip_vision_encoder_path)
65
+
66
+ controlnet = ControlNetModel.from_pretrained(
67
+ controlnet_model_name_or_path,
68
+ in_channels=in_channels_controlnet,
69
+ low_cpu_mem_usage=False,
70
+ ignore_mismatched_sizes=True
71
+ )
72
+ controlnet_tokenizer = CLIPTokenizer.from_pretrained(controlnet_clip_vision_encoder_path)
73
+ controlnet_text_encoder = CLIPTextModel.from_pretrained(controlnet_clip_vision_encoder_path)
74
+ controlnet_image_enc = CLIPVisionModelWithProjection.from_pretrained(controlnet_clip_vision_encoder_path)
75
+
76
+
77
+ point_net=PointNet()
78
+ reference_control_writer = ReferenceAttentionControl(
79
+ reference_unet,
80
+ do_classifier_free_guidance=False,
81
+ mode="write",
82
+ fusion_blocks="full",
83
+ )
84
+ reference_control_reader = ReferenceAttentionControl(
85
+ denoising_unet,
86
+ do_classifier_free_guidance=False,
87
+ mode="read",
88
+ fusion_blocks="full",
89
+ )
90
+
91
+
92
+
93
+ controlnet.load_state_dict(
94
+ torch.load(val_configs.model_path.manga_control_model_path, map_location="cpu"),
95
+ strict=False,
96
+ )
97
+ point_net.load_state_dict(
98
+ torch.load(val_configs.model_path.point_net_path, map_location="cpu"),
99
+ strict=False,
100
+ )
101
+ reference_unet.load_state_dict(
102
+ torch.load(val_configs.model_path.manga_reference_model_path, map_location="cpu"),
103
+ strict=False,
104
+ )
105
+ denoising_unet.load_state_dict(
106
+ torch.load(val_configs.model_path.manga_main_model_path, map_location="cpu"),
107
+ strict=False,
108
+ )
109
+ pipe = MangaNinjiaPipeline(
110
+ reference_unet=reference_unet,
111
+ controlnet=controlnet,
112
+ denoising_unet=denoising_unet,
113
+ vae=vae,
114
+ refnet_tokenizer=refnet_tokenizer,
115
+ refnet_text_encoder=refnet_text_encoder,
116
+ refnet_image_encoder=refnet_image_enc,
117
+ controlnet_tokenizer=controlnet_tokenizer,
118
+ controlnet_text_encoder=controlnet_text_encoder,
119
+ controlnet_image_encoder=controlnet_image_enc,
120
+ scheduler=noise_scheduler,
121
+ point_net=point_net
122
+ )
123
+ pipe = pipe.to(torch.device(device))
124
+ def string_to_np_array(coord_string):
125
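+ # e.g. "[[12, 34], [56, 78]]" -> array([[12, 34], [56, 78]])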
+ coord_string = coord_string.strip('[]')
126
+ coords = re.findall(r'\d+', coord_string)
127
+ coords = list(map(int, coords))
128
+ coord_array = np.array(coords).reshape(-1, 2)
129
+ return coord_array
130
+ def infer_single(is_lineart, ref_image, target_image, output_coords_ref, output_coords_base, seed = -1, num_inference_steps=20, guidance_scale_ref = 9, guidance_scale_point =15 ):
131
+ """
132
+ ref_image / target_image: PIL RGB images (reference and colorization target)
133
+ output_coords_ref / output_coords_base: stringified lists of matched point coordinates
134
+ """
135
+ generator = torch.cuda.manual_seed(seed)
136
+ matrix1 = np.zeros((512, 512), dtype=np.uint8)
137
+ matrix2 = np.zeros((512, 512), dtype=np.uint8)
138
+ output_coords_ref = string_to_np_array(output_coords_ref)
139
+ output_coords_base = string_to_np_array(output_coords_base)
140
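+ # Encode the clicked correspondences as sparse 512x512 index maps: the k-th matched pair
+ # is written as the value k+1 at its pixel location in both maps.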
+ for index, (coords_ref,coords_base) in enumerate(zip(output_coords_ref,output_coords_base)):
141
+ y1, x1 = coords_ref
142
+ y2, x2 = coords_base
143
+ matrix1[y1, x1] = index + 1
144
+ matrix2[y2, x2] = index + 1
145
+ point_ref = torch.from_numpy(matrix1).unsqueeze(0).unsqueeze(0)
146
+ point_main = torch.from_numpy(matrix2).unsqueeze(0).unsqueeze(0)
147
+ preprocessor.to(device,dtype=torch.float32)
148
+ pipe_out = pipe(
149
+ is_lineart,
150
+ ref_image,
151
+ target_image,
152
+ target_image,
153
+ denosing_steps=num_inference_steps,
154
+ processing_res=512,
155
+ match_input_res=True,
156
+ batch_size=1,
157
+ show_progress_bar=True,
158
+ guidance_scale_ref=guidance_scale_ref,
159
+ guidance_scale_point=guidance_scale_point,
160
+ preprocessor=preprocessor,
161
+ generator=generator,
162
+ point_ref=point_ref,
163
+ point_main=point_main,
164
+ )
165
+ return pipe_out
166
+
167
+
168
+ def inference_single_image(ref_image,
169
+ tar_image,
170
+ ddim_steps,
171
+ scale_ref,
172
+ scale_point,
173
+ seed,
174
+ output_coords1,
175
+ output_coords2,
176
+ is_lineart
177
+ ):
178
+ if seed == -1:
179
+ seed = np.random.randint(10000)
180
+ pipe_out = infer_single(is_lineart, ref_image, tar_image, output_coords_ref=output_coords1, output_coords_base=output_coords2,seed=seed ,num_inference_steps=ddim_steps, guidance_scale_ref = scale_ref, guidance_scale_point = scale_point
181
+ )
182
+ return pipe_out
183
+ clicked_points_img1 = []
184
+ clicked_points_img2 = []
185
+ current_img_idx = 0
186
+ max_clicks = 14
187
+ point_size = 8
188
+ colors = [(255, 0, 0), (0, 255, 0)]
189
+
190
+ # Process images: resizing them to 512x512
191
+ def process_image(ref, base):
192
+ ref_resized = cv2.resize(ref, (512, 512)) # Note OpenCV resize order is (width, height)
193
+ base_resized = cv2.resize(base, (512, 512))
194
+ return ref_resized, base_resized
195
+
196
+ # Convert string to numpy array of coordinates
197
+ def string_to_np_array(coord_string):
198
+ coord_string = coord_string.strip('[]')
199
+ coords = re.findall(r'\d+', coord_string)
200
+ coords = list(map(int, coords))
201
+ coord_array = np.array(coords).reshape(-1, 2)
202
+ return coord_array
203
+
204
+ # Function to handle click events
205
+ def get_select_coords(img1, img2, evt: gr.SelectData):
206
+ global clicked_points_img1, clicked_points_img2, current_img_idx
207
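+ # Swap the Gradio click index into (row, col) order so it can index the 512x512 point maps.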
+ click_coords = (evt.index[1], evt.index[0])
208
+
209
+ if current_img_idx == 0:
210
+ clicked_points_img1.append(click_coords)
211
+ if len(clicked_points_img1) > max_clicks:
212
+ clicked_points_img1 = []
213
+ current_img = img1
214
+ clicked_points = clicked_points_img1
215
+ else:
216
+ clicked_points_img2.append(click_coords)
217
+ if len(clicked_points_img2) > max_clicks:
218
+ clicked_points_img2 = []
219
+ current_img = img2
220
+ clicked_points = clicked_points_img2
221
+
222
+ current_img_idx = 1 - current_img_idx
223
+ img_pil = Image.fromarray(current_img.astype('uint8'))
224
+ draw = ImageDraw.Draw(img_pil)
225
+ for idx, point in enumerate(clicked_points):
226
+ x, y = point
227
+ color = colors[current_img_idx]
228
+ for dx in range(-point_size, point_size + 1):
229
+ for dy in range(-point_size, point_size + 1):
230
+ if 0 <= y + dy < img_pil.size[0] and 0 <= x + dx < img_pil.size[1]:
231
+ draw.point((y+dy, x+dx), fill=color)
232
+
233
+ img_out = np.array(img_pil)
234
+ coord_array = np.array([(x, y) for x, y in clicked_points])
235
+ return img_out, coord_array
236
+
237
+ # Function to clear the clicked points
238
+ def undo_last_point(ref, base):
239
+ global clicked_points_img1, clicked_points_img2, current_img_idx
240
+ current_img_idx=1-current_img_idx
241
+ if current_img_idx == 0 and clicked_points_img1:
242
+ clicked_points_img1.pop() # Undo last point in ref
243
+ elif current_img_idx == 1 and clicked_points_img2:
244
+ clicked_points_img2.pop() # Undo last point in base
245
+
246
+ # After removing the last point, redraw the image without it
247
+ if current_img_idx == 0:
248
+ current_img = ref
249
+ current_img_other = base
250
+ clicked_points = clicked_points_img1
251
+ clicked_points_other = clicked_points_img2
252
+ else:
253
+ current_img = base
254
+ current_img_other = ref
255
+ clicked_points = clicked_points_img2
256
+ clicked_points_other = clicked_points_img1
257
+
258
+ # Redraw the image without the last point
259
+ img_pil = Image.fromarray(current_img.astype('uint8'))
260
+ draw = ImageDraw.Draw(img_pil)
261
+ for idx, point in enumerate(clicked_points):
262
+ x, y = point
263
+ color = colors[current_img_idx]
264
+ for dx in range(-point_size, point_size + 1):
265
+ for dy in range(-point_size, point_size + 1):
266
+ if 0 <= y + dy < img_pil.size[0] and 0 <= x + dx < img_pil.size[1]:
267
+ draw.point((y+dy, x+dx), fill=color)
268
+ img_out = np.array(img_pil)
269
+
270
+
271
+ img_pil_other = Image.fromarray(current_img_other.astype('uint8'),)
272
+ draw_other = ImageDraw.Draw(img_pil_other)
273
+ for idx, point in enumerate(clicked_points_other):
274
+ x, y = point
275
+ color = colors[1-current_img_idx]
276
+ for dx in range(-point_size, point_size + 1):
277
+ for dy in range(-point_size, point_size + 1):
278
+ if 0 <= y + dy < img_pil.size[0] and 0 <= x + dx < img_pil.size[1]:
279
+ draw_other.point((y+dy, x+dx), fill=color)
280
+ img_out_other = np.array(img_pil_other)
281
+
282
+ coord_array = np.array([(x, y) for x, y in clicked_points])
283
+ # Return the updated image and coordinates as text
284
+ updated_coords = str(coord_array.tolist())
285
+
286
+ # If current_img_idx is 0, it means we are working with ref, so return for ref
287
+ if current_img_idx == 0:
288
+ coord_array2 = np.array([(x, y) for x, y in clicked_points_img2])
289
+ updated_coords2 = str(coord_array2.tolist())
290
+ return img_out, updated_coords, img_out_other, updated_coords2 # for ref image
291
+ else:
292
+ coord_array1 = np.array([(x, y) for x, y in clicked_points_img1])
293
+ updated_coords1 = str(coord_array1.tolist())
294
+ return img_out_other, updated_coords1, img_out, updated_coords # for base image
295
+
296
+
297
+ # Main function to run the image processing
298
+ def run_local(ref, base, *args):
299
+ image = Image.fromarray(base)
300
+ ref_image = Image.fromarray(ref)
301
+
302
+ pipe_out = inference_single_image(ref_image.copy(), image.copy(), *args)
303
+ to_save_dict = pipe_out.to_save_dict
304
+ to_save_dict['edit2'] = pipe_out.img_pil
305
+ return [to_save_dict['edit2'], to_save_dict['edge2_black']]
306
+
307
+ with gr.Blocks() as demo:
308
+ with gr.Column():
309
+ gr.Markdown("# MangaNinja: Line Art Colorization with Precise Reference Following")
310
+
311
+ with gr.Row():
312
+ baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=1, height=768)
313
+
314
+ with gr.Accordion("Advanced Option", open=True):
315
+ num_samples = 1
316
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
317
+ scale_ref = gr.Slider(label="Guidance of ref", minimum=0, maximum=30.0, value=9, step=0.1)
318
+ scale_point = gr.Slider(label="Guidance of points", minimum=0, maximum=30.0, value=15, step=0.1)
319
+ is_lineart = gr.Checkbox(label="Input is lineart", value=False)
320
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=-1)
321
+
322
+ gr.Markdown("### Tutorial")
323
+ gr.Markdown("1. Upload the reference image and target image. Note that for the target image, there are two modes: you can upload an RGB image, and the model will automatically extract the line art; or you can directly upload the line art by checking the 'input is lineart' option.")
324
+ gr.Markdown("2. Click 'Process Images' to resize the images to 512*512 resolution.")
325
+ gr.Markdown("3. (Optional) **Starting from the reference image**, **alternately** click on the reference and target images in sequence to define matching points. Use 'Undo' to revert the last action.")
326
+ gr.Markdown("4. Click 'Generate' to produce the result.")
327
+ gr.Markdown("# Upload the reference image and target image")
328
+
329
+ with gr.Row():
330
+ ref = gr.Image(label="Reference Image",)
331
+ base = gr.Image(label="Target Image",)
332
+ gr.Button("Process Images").click(process_image, inputs=[ref, base], outputs=[ref, base])
333
+
334
+ with gr.Row():
335
+ output_img1 = gr.Image(label="Reference Output")
336
+ output_coords1 = gr.Textbox(lines=2, label="Clicked Coordinates Image 1 (npy format)")
337
+ output_img2 = gr.Image(label="Base Output")
338
+ output_coords2 = gr.Textbox(lines=2, label="Clicked Coordinates Image 2 (npy format)")
339
+
340
+ # Image click select functions
341
+ ref.select(get_select_coords, [ref, base], [output_img1, output_coords1])
342
+ base.select(get_select_coords, [ref, base], [output_img2, output_coords2])
343
+
344
+ # Undo button
345
+ undo_button = gr.Button("Undo")
346
+ undo_button.click(undo_last_point, inputs=[ref, base], outputs=[output_img1, output_coords1, output_img2, output_coords2])
347
+
348
+ run_local_button = gr.Button(label="Generate", value="Generate")
349
+
350
+ with gr.Row():
351
+ gr.Examples(
352
+ examples=[
353
+ ['test_cases/hz0.png', 'test_cases/hz1.png'],
354
+ ['test_cases/more_cases/az0.png', 'test_cases/more_cases/az1.JPG'],
355
+ ['test_cases/more_cases/hi0.png', 'test_cases/more_cases/hi1.jpg'],
356
+ ['test_cases/more_cases/kn0.jpg', 'test_cases/more_cases/kn1.jpg'],
357
+ ['test_cases/more_cases/rk0.jpg', 'test_cases/more_cases/rk1.jpg'],
358
+
359
+
360
+ ],
361
+ inputs=[ref, base],
362
+ cache_examples=False,
363
+ examples_per_page=100
364
+ )
365
+
366
+
367
+ run_local_button.click(fn=run_local,
368
+ inputs=[ref,
369
+ base,
370
+ ddim_steps,
371
+ scale_ref,
372
+ scale_point,
373
+ seed,
374
+ output_coords1,
375
+ output_coords2,
376
+ is_lineart
377
+ ],
378
+ outputs=[baseline_gallery]
379
+ )
380
+
381
+ demo.launch(server_name="0.0.0.0")
scripts/infer.sh ADDED
@@ -0,0 +1,33 @@
1
+ #!/usr/bin/env bash
2
+ set -e
3
+ set -x
4
+ pretrained_model_name_or_path='./checkpoints/StableDiffusion'
5
+ image_encoder_path='./checkpoints/models/clip-vit-large-patch14'
6
+ controlnet_model_name_or_path='./checkpoints/models/control_v11p_sd15_lineart'
7
+ annotator_ckpts_path='./checkpoints/models/Annotators'
8
+
9
+ manga_reference_unet_path='./checkpoints/MangaNinjia/reference_unet.pth'
10
+ manga_main_model_path='./checkpoints/MangaNinjia/denoising_unet.pth'
11
+ manga_controlnet_model_path='./checkpoints/MangaNinjia/controlnet.pth'
12
+ point_net_path='./checkpoints/MangaNinjia/point_net.pth'
13
+ export CUDA_VISIBLE_DEVICES=0
14
+
15
+ input_reference_paths='./test_cases/hz0.png ./test_cases/hz1.png'
16
+ input_lineart_paths='./test_cases/hz1.png ./test_cases/hz0.png'
17
+ point_ref_paths='./test_cases/hz01_0.npy ./test_cases/hz01_1.npy'
18
+ point_lineart_paths='./test_cases/hz01_1.npy ./test_cases/hz01_0.npy'
19
+ cd ..
20
+ python infer.py \
21
+ --seed 0 \
22
+ --denoise_steps 50 \
23
+ --pretrained_model_name_or_path $pretrained_model_name_or_path --image_encoder_path $image_encoder_path \
24
+ --controlnet_model_name_or_path $controlnet_model_name_or_path --annotator_ckpts_path $annotator_ckpts_path \
25
+ --manga_reference_unet_path $manga_reference_unet_path --manga_main_model_path $manga_main_model_path \
26
+ --manga_controlnet_model_path $manga_controlnet_model_path --point_net_path $point_net_path \
27
+ --output_dir 'output' \
28
+ --guidance_scale_ref 9 \
29
+ --guidance_scale_point 15 \
30
+ --input_reference_paths $input_reference_paths \
31
+ --input_lineart_paths $input_lineart_paths \
32
+ --point_ref_paths $point_ref_paths \
33
+ --point_lineart_paths $point_lineart_paths \
src/annotator/lineart/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Caroline Chan
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
src/annotator/lineart/__init__.py ADDED
@@ -0,0 +1,152 @@
1
+ # From https://github.com/carolineec/informative-drawings
2
+ # MIT License
3
+
4
+ import os
5
+ import cv2
6
+ import torch
7
+ import numpy as np
8
+
9
+ import torch.nn as nn
10
+ from einops import rearrange
11
+
12
+
13
+ norm_layer = nn.InstanceNorm2d
14
+
15
+
16
+ class ResidualBlock(nn.Module):
17
+ def __init__(self, in_features):
18
+ super(ResidualBlock, self).__init__()
19
+
20
+ conv_block = [ nn.ReflectionPad2d(1),
21
+ nn.Conv2d(in_features, in_features, 3),
22
+ norm_layer(in_features),
23
+ nn.ReLU(inplace=True),
24
+ nn.ReflectionPad2d(1),
25
+ nn.Conv2d(in_features, in_features, 3),
26
+ norm_layer(in_features)
27
+ ]
28
+
29
+ self.conv_block = nn.Sequential(*conv_block)
30
+
31
+ def forward(self, x):
32
+ return x + self.conv_block(x)
33
+
34
+
35
+ class Generator(nn.Module):
36
+ def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
37
+ super(Generator, self).__init__()
38
+
39
+ # Initial convolution block
40
+ model0 = [ nn.ReflectionPad2d(3),
41
+ nn.Conv2d(input_nc, 64, 7),
42
+ norm_layer(64),
43
+ nn.ReLU(inplace=True) ]
44
+ self.model0 = nn.Sequential(*model0)
45
+
46
+ # Downsampling
47
+ model1 = []
48
+ in_features = 64
49
+ out_features = in_features*2
50
+ for _ in range(2):
51
+ model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
52
+ norm_layer(out_features),
53
+ nn.ReLU(inplace=True) ]
54
+ in_features = out_features
55
+ out_features = in_features*2
56
+ self.model1 = nn.Sequential(*model1)
57
+
58
+ model2 = []
59
+ # Residual blocks
60
+ for _ in range(n_residual_blocks):
61
+ model2 += [ResidualBlock(in_features)]
62
+ self.model2 = nn.Sequential(*model2)
63
+
64
+ # Upsampling
65
+ model3 = []
66
+ out_features = in_features//2
67
+ for _ in range(2):
68
+ model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
69
+ norm_layer(out_features),
70
+ nn.ReLU(inplace=True) ]
71
+ in_features = out_features
72
+ out_features = in_features//2
73
+ self.model3 = nn.Sequential(*model3)
74
+
75
+ # Output layer
76
+ model4 = [ nn.ReflectionPad2d(3),
77
+ nn.Conv2d(64, output_nc, 7)]
78
+ if sigmoid:
79
+ model4 += [nn.Sigmoid()]
80
+
81
+ self.model4 = nn.Sequential(*model4)
82
+
83
+ def forward(self, x, cond=None):
84
+ out = self.model0(x)
85
+ out = self.model1(out)
86
+ out = self.model2(out)
87
+ out = self.model3(out)
88
+ out = self.model4(out)
89
+
90
+ return out
91
+
92
+
93
+ class LineartDetector:
94
+ def __init__(self, annotator_ckpts_path):
95
+ self.annotator_ckpts_path = annotator_ckpts_path
96
+ self.model = self.load_model('sk_model.pth')
97
+ self.model_coarse = self.load_model('sk_model2.pth')
98
+
99
+ def load_model(self, name):
100
+ remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/" + name
101
+ modelpath = os.path.join(self.annotator_ckpts_path, name)
102
+ if not os.path.exists(modelpath):
103
+ from basicsr.utils.download_util import load_file_from_url
104
+ load_file_from_url(remote_model_path, model_dir=self.annotator_ckpts_path)
105
+ model = Generator(3, 1, 3)
106
+ model.load_state_dict(torch.load(modelpath, map_location=torch.device('cpu')))
107
+ model.eval()
108
+ model = model.cuda()
109
+ return model
110
+
111
+ def __call__(self, input_image, coarse):
112
+ model = self.model_coarse if coarse else self.model
113
+ assert input_image.ndim == 3
114
+ image = input_image
115
+ with torch.no_grad():
116
+ image = torch.from_numpy(image).float().cuda()
117
+ image = image / 255.0
118
+ image = rearrange(image, 'h w c -> 1 c h w')
119
+ line = model(image)[0][0]
120
+
121
+ line = line.cpu().numpy()
122
+ line = (line * 255.0).clip(0, 255).astype(np.uint8)
123
+
124
+ return line
125
+
126
+ class BatchLineartDetector:
127
+ def __init__(self, annotator_ckpts_path):
128
+ self.annotator_ckpts_path = annotator_ckpts_path
129
+ self.model = self.load_model('sk_model.pth')
130
+
131
+ def load_model(self, name):
132
+ remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/" + name
133
+ modelpath = os.path.join(self.annotator_ckpts_path, name)
134
+ if not os.path.exists(modelpath):
135
+ from basicsr.utils.download_util import load_file_from_url
136
+ load_file_from_url(remote_model_path, model_dir=self.annotator_ckpts_path)
137
+ model = Generator(3, 1, 3)
138
+ model.load_state_dict(torch.load(modelpath, map_location=torch.device('cpu')))
139
+ model.eval()
140
+ return model
141
+
142
+ def to(self, device, dtype):
143
+ self.model.to(device, dtype=dtype)
144
+
145
+ def __call__(self, input_image, mean=-1., std=2.):
146
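+ # With the default mean=-1, std=2, inputs in [-1, 1] are mapped to the [0, 1] range the
+ # generator expects; the predicted line map is then inverted before being returned.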
+ model = self.model
147
+ image = input_image
148
+ with torch.no_grad():
149
+ image = (image - mean) / std
150
+ line = model(image)
151
+ line = 1 - line
152
+ return line
src/models/attention.py ADDED
@@ -0,0 +1,444 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ from diffusers.models.attention import AdaLayerNorm, FeedForward
7
+ from src.models.attention_processor import Attention
8
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
9
+ from einops import rearrange
10
+ from torch import nn
11
+
12
+
13
+ class BasicTransformerBlock(nn.Module):
14
+ r"""
15
+ A basic Transformer block.
16
+
17
+ Parameters:
18
+ dim (`int`): The number of channels in the input and output.
19
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
20
+ attention_head_dim (`int`): The number of channels in each head.
21
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
22
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
23
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
24
+ num_embeds_ada_norm (:
25
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
26
+ attention_bias (:
27
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
28
+ only_cross_attention (`bool`, *optional*):
29
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
30
+ double_self_attention (`bool`, *optional*):
31
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
32
+ upcast_attention (`bool`, *optional*):
33
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
34
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
35
+ Whether to use learnable elementwise affine parameters for normalization.
36
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
37
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
38
+ final_dropout (`bool` *optional*, defaults to False):
39
+ Whether to apply a final dropout after the last feed-forward layer.
40
+ attention_type (`str`, *optional*, defaults to `"default"`):
41
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
42
+ positional_embeddings (`str`, *optional*, defaults to `None`):
43
+ The type of positional embeddings to apply to.
44
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
45
+ The maximum number of positional embeddings to apply.
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ dim: int,
51
+ num_attention_heads: int,
52
+ attention_head_dim: int,
53
+ dropout=0.0,
54
+ cross_attention_dim: Optional[int] = None,
55
+ activation_fn: str = "geglu",
56
+ num_embeds_ada_norm: Optional[int] = None,
57
+ attention_bias: bool = False,
58
+ only_cross_attention: bool = False,
59
+ double_self_attention: bool = False,
60
+ upcast_attention: bool = False,
61
+ norm_elementwise_affine: bool = True,
62
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
63
+ norm_eps: float = 1e-5,
64
+ final_dropout: bool = False,
65
+ attention_type: str = "default",
66
+ positional_embeddings: Optional[str] = None,
67
+ num_positional_embeddings: Optional[int] = None,
68
+ ):
69
+ super().__init__()
70
+ self.only_cross_attention = only_cross_attention
71
+
72
+ self.use_ada_layer_norm_zero = (
73
+ num_embeds_ada_norm is not None
74
+ ) and norm_type == "ada_norm_zero"
75
+ self.use_ada_layer_norm = (
76
+ num_embeds_ada_norm is not None
77
+ ) and norm_type == "ada_norm"
78
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
79
+ self.use_layer_norm = norm_type == "layer_norm"
80
+
81
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
82
+ raise ValueError(
83
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
84
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
85
+ )
86
+
87
+ if positional_embeddings and (num_positional_embeddings is None):
88
+ raise ValueError(
89
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
90
+ )
91
+
92
+ if positional_embeddings == "sinusoidal":
93
+ self.pos_embed = SinusoidalPositionalEmbedding(
94
+ dim, max_seq_length=num_positional_embeddings
95
+ )
96
+ else:
97
+ self.pos_embed = None
98
+
99
+ # Define 3 blocks. Each block has its own normalization layer.
100
+ # 1. Self-Attn
101
+ if self.use_ada_layer_norm:
102
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
103
+ elif self.use_ada_layer_norm_zero:
104
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
105
+ else:
106
+ self.norm1 = nn.LayerNorm(
107
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
108
+ )
109
+
110
+ self.attn1 = Attention(
111
+ query_dim=dim,
112
+ heads=num_attention_heads,
113
+ dim_head=attention_head_dim,
114
+ dropout=dropout,
115
+ bias=attention_bias,
116
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
117
+ upcast_attention=upcast_attention,
118
+ )
119
+
120
+ # 2. Cross-Attn
121
+ if cross_attention_dim is not None or double_self_attention:
122
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
123
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
124
+ # the second cross attention block.
125
+ self.norm2 = (
126
+ AdaLayerNorm(dim, num_embeds_ada_norm)
127
+ if self.use_ada_layer_norm
128
+ else nn.LayerNorm(
129
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
130
+ )
131
+ )
132
+ self.attn2 = Attention(
133
+ query_dim=dim,
134
+ cross_attention_dim=cross_attention_dim
135
+ if not double_self_attention
136
+ else None,
137
+ heads=num_attention_heads,
138
+ dim_head=attention_head_dim,
139
+ dropout=dropout,
140
+ bias=attention_bias,
141
+ upcast_attention=upcast_attention,
142
+ ) # is self-attn if encoder_hidden_states is none
143
+ else:
144
+ self.norm2 = None
145
+ self.attn2 = None
146
+
147
+ # 3. Feed-forward
148
+ if not self.use_ada_layer_norm_single:
149
+ self.norm3 = nn.LayerNorm(
150
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
151
+ )
152
+
153
+ self.ff = FeedForward(
154
+ dim,
155
+ dropout=dropout,
156
+ activation_fn=activation_fn,
157
+ final_dropout=final_dropout,
158
+ )
159
+
160
+ # 4. Fuser
161
+ if attention_type == "gated" or attention_type == "gated-text-image":
162
+ self.fuser = GatedSelfAttentionDense(
163
+ dim, cross_attention_dim, num_attention_heads, attention_head_dim
164
+ )
165
+
166
+ # 5. Scale-shift for PixArt-Alpha.
167
+ if self.use_ada_layer_norm_single:
168
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
169
+
170
+ # let chunk size default to None
171
+ self._chunk_size = None
172
+ self._chunk_dim = 0
173
+
174
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
175
+ # Sets chunk feed-forward
176
+ self._chunk_size = chunk_size
177
+ self._chunk_dim = dim
178
+
179
+ def forward(
180
+ self,
181
+ hidden_states: torch.FloatTensor,
182
+ attention_mask: Optional[torch.FloatTensor] = None,
183
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
184
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
185
+ timestep: Optional[torch.LongTensor] = None,
186
+ cross_attention_kwargs: Dict[str, Any] = None,
187
+ class_labels: Optional[torch.LongTensor] = None,
188
+ ) -> torch.FloatTensor:
189
+ # Notice that normalization is always applied before the real computation in the following blocks.
190
+ # 0. Self-Attention
191
+ batch_size = hidden_states.shape[0]
192
+
193
+ if self.use_ada_layer_norm:
194
+ norm_hidden_states = self.norm1(hidden_states, timestep)
195
+ elif self.use_ada_layer_norm_zero:
196
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
197
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
198
+ )
199
+ elif self.use_layer_norm:
200
+ norm_hidden_states = self.norm1(hidden_states)
201
+ elif self.use_ada_layer_norm_single:
202
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
203
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
204
+ ).chunk(6, dim=1)
205
+ norm_hidden_states = self.norm1(hidden_states)
206
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
207
+ norm_hidden_states = norm_hidden_states.squeeze(1)
208
+ else:
209
+ raise ValueError("Incorrect norm used")
210
+
211
+ if self.pos_embed is not None:
212
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
213
+
214
+ # 1. Retrieve lora scale.
215
+ lora_scale = (
216
+ cross_attention_kwargs.get("scale", 1.0)
217
+ if cross_attention_kwargs is not None
218
+ else 1.0
219
+ )
220
+
221
+ # 2. Prepare GLIGEN inputs
222
+ cross_attention_kwargs = (
223
+ cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
224
+ )
225
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
226
+
227
+ attn_output = self.attn1(
228
+ norm_hidden_states,
229
+ encoder_hidden_states=encoder_hidden_states
230
+ if self.only_cross_attention
231
+ else None,
232
+ attention_mask=attention_mask,
233
+ **cross_attention_kwargs,
234
+ )
235
+ if self.use_ada_layer_norm_zero:
236
+ attn_output = gate_msa.unsqueeze(1) * attn_output
237
+ elif self.use_ada_layer_norm_single:
238
+ attn_output = gate_msa * attn_output
239
+
240
+ hidden_states = attn_output + hidden_states
241
+ if hidden_states.ndim == 4:
242
+ hidden_states = hidden_states.squeeze(1)
243
+
244
+ # 2.5 GLIGEN Control
245
+ if gligen_kwargs is not None:
246
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
247
+
248
+ # 3. Cross-Attention
249
+ if self.attn2 is not None:
250
+ if self.use_ada_layer_norm:
251
+ norm_hidden_states = self.norm2(hidden_states, timestep)
252
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
253
+ norm_hidden_states = self.norm2(hidden_states)
254
+ elif self.use_ada_layer_norm_single:
255
+ # For PixArt norm2 isn't applied here:
256
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
257
+ norm_hidden_states = hidden_states
258
+ else:
259
+ raise ValueError("Incorrect norm")
260
+
261
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
262
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
263
+
264
+ attn_output = self.attn2(
265
+ norm_hidden_states,
266
+ encoder_hidden_states=encoder_hidden_states,
267
+ attention_mask=encoder_attention_mask,
268
+ **cross_attention_kwargs,
269
+ )
270
+ hidden_states = attn_output + hidden_states
271
+
272
+ # 4. Feed-forward
273
+ if not self.use_ada_layer_norm_single:
274
+ norm_hidden_states = self.norm3(hidden_states)
275
+
276
+ if self.use_ada_layer_norm_zero:
277
+ norm_hidden_states = (
278
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
279
+ )
280
+
281
+ if self.use_ada_layer_norm_single:
282
+ norm_hidden_states = self.norm2(hidden_states)
283
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
284
+
285
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
286
+
287
+ if self.use_ada_layer_norm_zero:
288
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
289
+ elif self.use_ada_layer_norm_single:
290
+ ff_output = gate_mlp * ff_output
291
+
292
+ hidden_states = ff_output + hidden_states
293
+ if hidden_states.ndim == 4:
294
+ hidden_states = hidden_states.squeeze(1)
295
+
296
+ return hidden_states
297
+
298
+
299
+ class TemporalBasicTransformerBlock(nn.Module):
300
+ def __init__(
301
+ self,
302
+ dim: int,
303
+ num_attention_heads: int,
304
+ attention_head_dim: int,
305
+ dropout=0.0,
306
+ cross_attention_dim: Optional[int] = None,
307
+ activation_fn: str = "geglu",
308
+ num_embeds_ada_norm: Optional[int] = None,
309
+ attention_bias: bool = False,
310
+ only_cross_attention: bool = False,
311
+ upcast_attention: bool = False,
312
+ unet_use_cross_frame_attention=None,
313
+ unet_use_temporal_attention=None,
314
+ ):
315
+ super().__init__()
316
+ self.only_cross_attention = only_cross_attention
317
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
318
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
319
+ self.unet_use_temporal_attention = unet_use_temporal_attention
320
+
321
+ # SC-Attn
322
+ self.attn1 = Attention(
323
+ query_dim=dim,
324
+ heads=num_attention_heads,
325
+ dim_head=attention_head_dim,
326
+ dropout=dropout,
327
+ bias=attention_bias,
328
+ upcast_attention=upcast_attention,
329
+ )
330
+ self.norm1 = (
331
+ AdaLayerNorm(dim, num_embeds_ada_norm)
332
+ if self.use_ada_layer_norm
333
+ else nn.LayerNorm(dim)
334
+ )
335
+
336
+ # Cross-Attn
337
+ if cross_attention_dim is not None:
338
+ self.attn2 = Attention(
339
+ query_dim=dim,
340
+ cross_attention_dim=cross_attention_dim,
341
+ heads=num_attention_heads,
342
+ dim_head=attention_head_dim,
343
+ dropout=dropout,
344
+ bias=attention_bias,
345
+ upcast_attention=upcast_attention,
346
+ )
347
+ else:
348
+ self.attn2 = None
349
+
350
+ if cross_attention_dim is not None:
351
+ self.norm2 = (
352
+ AdaLayerNorm(dim, num_embeds_ada_norm)
353
+ if self.use_ada_layer_norm
354
+ else nn.LayerNorm(dim)
355
+ )
356
+ else:
357
+ self.norm2 = None
358
+
359
+ # Feed-forward
360
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
361
+ self.norm3 = nn.LayerNorm(dim)
362
+ self.use_ada_layer_norm_zero = False
363
+
364
+ # Temp-Attn
365
+ assert unet_use_temporal_attention is not None
366
+ if unet_use_temporal_attention:
367
+ self.attn_temp = Attention(
368
+ query_dim=dim,
369
+ heads=num_attention_heads,
370
+ dim_head=attention_head_dim,
371
+ dropout=dropout,
372
+ bias=attention_bias,
373
+ upcast_attention=upcast_attention,
374
+ )
375
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
376
+ self.norm_temp = (
377
+ AdaLayerNorm(dim, num_embeds_ada_norm)
378
+ if self.use_ada_layer_norm
379
+ else nn.LayerNorm(dim)
380
+ )
381
+
382
+ def forward(
383
+ self,
384
+ hidden_states,
385
+ encoder_hidden_states=None,
386
+ timestep=None,
387
+ attention_mask=None,
388
+ video_length=None,
389
+ ):
390
+ norm_hidden_states = (
391
+ self.norm1(hidden_states, timestep)
392
+ if self.use_ada_layer_norm
393
+ else self.norm1(hidden_states)
394
+ )
395
+
396
+ if self.unet_use_cross_frame_attention:
397
+ hidden_states = (
398
+ self.attn1(
399
+ norm_hidden_states,
400
+ attention_mask=attention_mask,
401
+ video_length=video_length,
402
+ )
403
+ + hidden_states
404
+ )
405
+ else:
406
+ hidden_states = (
407
+ self.attn1(norm_hidden_states, attention_mask=attention_mask)
408
+ + hidden_states
409
+ )
410
+
411
+ if self.attn2 is not None:
412
+ # Cross-Attention
413
+ norm_hidden_states = (
414
+ self.norm2(hidden_states, timestep)
415
+ if self.use_ada_layer_norm
416
+ else self.norm2(hidden_states)
417
+ )
418
+ hidden_states = (
419
+ self.attn2(
420
+ norm_hidden_states,
421
+ encoder_hidden_states=encoder_hidden_states,
422
+ attention_mask=attention_mask,
423
+ )
424
+ + hidden_states
425
+ )
426
+
427
+ # Feed-forward
428
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
429
+
430
+ # Temporal-Attention
431
+ if self.unet_use_temporal_attention:
432
+ d = hidden_states.shape[1]
433
+ hidden_states = rearrange(
434
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
435
+ )
436
+ norm_hidden_states = (
437
+ self.norm_temp(hidden_states, timestep)
438
+ if self.use_ada_layer_norm
439
+ else self.norm_temp(hidden_states)
440
+ )
441
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
442
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
443
+
444
+ return hidden_states
src/models/attention_processor.py ADDED
The diff for this file is too large to render. See raw diff
 
src/models/mutual_self_attention_multi_scale.py ADDED
@@ -0,0 +1,389 @@
1
+ # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py
2
+ from typing import Any, Dict, Optional
3
+
4
+ import torch
5
+ from einops import rearrange
6
+
7
+ from src.models.attention import TemporalBasicTransformerBlock
8
+
9
+ from .attention import BasicTransformerBlock
10
+
11
+
12
+ def torch_dfs(model: torch.nn.Module):
13
+ result = [model]
14
+ for child in model.children():
15
+ result += torch_dfs(child)
16
+ return result
17
+
18
+ def filter_matrices_by_size(matrix_list, reference_matrix):
19
+ ref_shape = reference_matrix.shape[-2:]
20
+ return [matrix for matrix in matrix_list if matrix.shape[-2:] == ref_shape]
21
+
22
+ class ReferenceAttentionControl:
23
+ def __init__(
24
+ self,
25
+ unet,
26
+ mode="write",
27
+ do_classifier_free_guidance=False,
28
+ attention_auto_machine_weight=float("inf"),
29
+ gn_auto_machine_weight=1.0,
30
+ style_fidelity=1.0,
31
+ reference_attn=True,
32
+ reference_adain=False,
33
+ fusion_blocks="midup",
34
+ batch_size=1,
35
+ ) -> None:
36
+ # 10. Modify self attention and group norm
37
+ self.unet = unet
38
+ assert mode in ["read", "write"]
39
+ assert fusion_blocks in ["midup", "full"]
40
+ self.reference_attn = reference_attn
41
+ self.reference_adain = reference_adain
42
+ self.fusion_blocks = fusion_blocks
43
+ self.register_reference_hooks(
44
+ mode,
45
+ do_classifier_free_guidance,
46
+ attention_auto_machine_weight,
47
+ gn_auto_machine_weight,
48
+ style_fidelity,
49
+ reference_attn,
50
+ reference_adain,
51
+ fusion_blocks,
52
+ batch_size=batch_size,
53
+ )
54
+ self.point_embedding=[]
55
+ def register_reference_hooks(
56
+ self,
57
+ mode,
58
+ do_classifier_free_guidance,
59
+ attention_auto_machine_weight,
60
+ gn_auto_machine_weight,
61
+ style_fidelity,
62
+ reference_attn,
63
+ reference_adain,
64
+ dtype=torch.float16,
65
+ batch_size=1,
66
+ num_images_per_prompt=1,
67
+ device=torch.device("cpu"),
68
+ fusion_blocks="midup",
69
+ ):
70
+ MODE = mode
71
+ do_classifier_free_guidance = do_classifier_free_guidance
72
+ attention_auto_machine_weight = attention_auto_machine_weight
73
+ gn_auto_machine_weight = gn_auto_machine_weight
74
+ style_fidelity = style_fidelity
75
+ reference_attn = reference_attn
76
+ reference_adain = reference_adain
77
+ fusion_blocks = fusion_blocks
78
+ num_images_per_prompt = num_images_per_prompt
79
+ dtype = dtype
80
+ if do_classifier_free_guidance:
81
+ uc_mask = (
82
+ torch.Tensor(
83
+ [1] * batch_size * num_images_per_prompt * 16
84
+ + [0] * batch_size * num_images_per_prompt * 16
85
+ )
86
+ .to(device)
87
+ .bool()
88
+ )
89
+ else:
90
+ uc_mask = (
91
+ torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
92
+ .to(device)
93
+ .bool()
94
+ )
95
+
96
+ def hacked_basic_transformer_inner_forward(
97
+ self,
98
+ hidden_states: torch.FloatTensor,
99
+ attention_mask: Optional[torch.FloatTensor] = None,
100
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
101
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
102
+ timestep: Optional[torch.LongTensor] = None,
103
+ cross_attention_kwargs: Dict[str, Any] = None,
104
+ class_labels: Optional[torch.LongTensor] = None,
105
+ video_length=None,
106
+ ):
107
+ if self.use_ada_layer_norm: # False
108
+ norm_hidden_states = self.norm1(hidden_states, timestep)
109
+ elif self.use_ada_layer_norm_zero:
110
+ (
111
+ norm_hidden_states,
112
+ gate_msa,
113
+ shift_mlp,
114
+ scale_mlp,
115
+ gate_mlp,
116
+ ) = self.norm1(
117
+ hidden_states,
118
+ timestep,
119
+ class_labels,
120
+ hidden_dtype=hidden_states.dtype,
121
+ )
122
+ else:
123
+ norm_hidden_states = self.norm1(hidden_states)
124
+
125
+ # 1. Self-Attention
126
+ # self.only_cross_attention = False
127
+ cross_attention_kwargs = (
128
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
129
+ )
130
+ if self.only_cross_attention:
131
+ attn_output = self.attn1(
132
+ norm_hidden_states,
133
+ encoder_hidden_states=encoder_hidden_states
134
+ if self.only_cross_attention
135
+ else None,
136
+ attention_mask=attention_mask,
137
+ **cross_attention_kwargs,
138
+ )
139
+ else:
140
+ if MODE == "write":
141
+ self.bank.append(norm_hidden_states.clone())
142
+ attn_output = self.attn1(
143
+ norm_hidden_states,
144
+ encoder_hidden_states=encoder_hidden_states
145
+ if self.only_cross_attention
146
+ else None,
147
+ attention_mask=attention_mask,
148
+ **cross_attention_kwargs,
149
+ )
150
+ if MODE == "read":
151
+ bank_fea = [
152
+ rearrange(
153
+ d.unsqueeze(1).repeat(1, 1, 1, 1),
154
+ "b t l c -> (b t) l c",
155
+ )
156
+ for d in self.bank
157
+ ]
158
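+ # If point embeddings of a matching spatial size are available, add them to the query and key
+ # features (target features get the target point map, reference features the reference point map),
+ # presumably so that matched points attend to each other; the values are left unmodified.
+ # The except branch falls back to plain reference attention when no point embedding fits this block.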
+ try:
159
+ modify_norm_hidden_states = torch.cat(
160
+ [norm_hidden_states+self.point_bank_main[0].repeat(norm_hidden_states.shape[0],1,1)] + [bank_fea[0]+self.point_bank_ref[0].repeat(norm_hidden_states.shape[0],1,1)], dim=1
161
+ )
162
+ modify_norm_hidden_states_v = torch.cat(
163
+ [norm_hidden_states] + bank_fea, dim=1
164
+ )
165
+ # import ipdb;ipdb.set_trace()
166
+ hidden_states_uc = (
167
+ self.attn1(
168
+ norm_hidden_states+self.point_bank_main[0].repeat(norm_hidden_states.shape[0],1,1),
169
+ encoder_hidden_states=modify_norm_hidden_states,
170
+ encoder_hidden_states_v=modify_norm_hidden_states_v,
171
+ attention_mask=attention_mask,
172
+ )
173
+ + hidden_states
174
+ )
175
+ except:
176
+ modify_norm_hidden_states = torch.cat(
177
+ [norm_hidden_states] + bank_fea, dim=1
178
+ )
179
+ hidden_states_uc = (
180
+ self.attn1(
181
+ norm_hidden_states,
182
+ encoder_hidden_states=modify_norm_hidden_states,
183
+ attention_mask=attention_mask,
184
+ )
185
+ + hidden_states
186
+ )
187
+ if do_classifier_free_guidance:
188
+ hidden_states_c = hidden_states_uc.clone()
189
+ _uc_mask = uc_mask.clone()
190
+ if hidden_states.shape[0] != _uc_mask.shape[0]:
191
+ _uc_mask = (
192
+ torch.Tensor(
193
+ [1] * (hidden_states.shape[0] // 3)
194
+ + [0] * (hidden_states.shape[0] // 3)
195
+ + [0] * (hidden_states.shape[0] // 3)
196
+ )
197
+ .to(device)
198
+ .bool()
199
+ )
200
+ _uc_mask_2 = (
201
+ torch.Tensor(
202
+ [0] * (hidden_states.shape[0] // 3)
203
+ + [1] * (hidden_states.shape[0] // 3)
204
+ + [0] * (hidden_states.shape[0] // 3)
205
+ )
206
+ .to(device)
207
+ .bool()
208
+ )
209
+ hidden_states_c[_uc_mask] = (
210
+ self.attn1(
211
+ norm_hidden_states[_uc_mask],
212
+ encoder_hidden_states=norm_hidden_states[_uc_mask],
213
+ attention_mask=attention_mask,
214
+ )
215
+ + hidden_states[_uc_mask]
216
+ )
217
+ modify_norm_hidden_states = torch.cat(
218
+ [norm_hidden_states] + bank_fea, dim=1
219
+ )
220
+ hidden_states_c[_uc_mask_2] = (
221
+ self.attn1(
222
+ norm_hidden_states[_uc_mask_2],
223
+ encoder_hidden_states=modify_norm_hidden_states[_uc_mask_2],
224
+ attention_mask=attention_mask,
225
+ )
226
+ + hidden_states[_uc_mask_2]
227
+ )
228
+ hidden_states = hidden_states_c.clone()
229
+ else:
230
+ hidden_states = hidden_states_uc
231
+
232
+
233
+ if self.attn2 is not None:
234
+ # Cross-Attention
235
+ norm_hidden_states = (
236
+ self.norm2(hidden_states, timestep)
237
+ if self.use_ada_layer_norm
238
+ else self.norm2(hidden_states)
239
+ )
240
+ hidden_states = (
241
+ self.attn2(
242
+ norm_hidden_states,
243
+ encoder_hidden_states=encoder_hidden_states,
244
+ attention_mask=attention_mask,
245
+ )
246
+ + hidden_states
247
+ )
248
+
249
+ # Feed-forward
250
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
251
+ return hidden_states
252
+ # import ipdb;ipdb.set_trace()
253
+ if self.use_ada_layer_norm_zero:
254
+ attn_output = gate_msa.unsqueeze(1) * attn_output
255
+ try:
256
+ hidden_states = attn_output + hidden_states
257
+ except:
258
+ import ipdb;ipdb.set_trace()
259
+ if self.attn2 is not None:
260
+ norm_hidden_states = (
261
+ self.norm2(hidden_states, timestep)
262
+ if self.use_ada_layer_norm
263
+ else self.norm2(hidden_states)
264
+ )
265
+
266
+ # 2. Cross-Attention
267
+ attn_output = self.attn2(
268
+ norm_hidden_states,
269
+ encoder_hidden_states=encoder_hidden_states,
270
+ attention_mask=encoder_attention_mask,
271
+ **cross_attention_kwargs,
272
+ )
273
+ hidden_states = attn_output + hidden_states
274
+
275
+ # 3. Feed-forward
276
+ norm_hidden_states = self.norm3(hidden_states)
277
+
278
+ if self.use_ada_layer_norm_zero:
279
+ norm_hidden_states = (
280
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
281
+ )
282
+
283
+ ff_output = self.ff(norm_hidden_states)
284
+
285
+ if self.use_ada_layer_norm_zero:
286
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
287
+
288
+ hidden_states = ff_output + hidden_states
289
+
290
+ return hidden_states
291
+
292
+ if self.reference_attn:
293
+ if self.fusion_blocks == "midup":
294
+ attn_modules = [
295
+ module
296
+ for module in (
297
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
298
+ )
299
+ if isinstance(module, BasicTransformerBlock)
300
+ ]
301
+ elif self.fusion_blocks == "full":
302
+ attn_modules = [
303
+ module
304
+ for module in torch_dfs(self.unet)
305
+ if isinstance(module, BasicTransformerBlock)
306
+ ]
307
+ attn_modules = sorted(
308
+ attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
309
+ )
310
+
311
+ for i, module in enumerate(attn_modules):
312
+ module._original_inner_forward = module.forward
313
+ if isinstance(module, BasicTransformerBlock):
314
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
315
+ module, BasicTransformerBlock
316
+ )
317
+ module.bank = []
318
+ module.point_bank_ref=[]
319
+ module.point_bank_main=[]
320
+ module.attn_weight = float(i) / float(len(attn_modules))
321
+
322
+ def update(self, writer,point_embedding_ref=None,point_embedding_main=None,dtype=torch.float16):
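+ # Copy the writer's cached self-attention features (banks) into the reader's blocks,
+ # keeping only the point embeddings whose spatial size matches each block's features.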
323
+ if self.reference_attn:
324
+ if self.fusion_blocks == "midup":
325
+ reader_attn_modules = [
326
+ module
327
+ for module in (
328
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
329
+ )
330
+ if isinstance(module, TemporalBasicTransformerBlock)
331
+ ]
332
+ writer_attn_modules = [
333
+ module
334
+ for module in (
335
+ torch_dfs(writer.unet.mid_block)
336
+ + torch_dfs(writer.unet.up_blocks)
337
+ )
338
+ if isinstance(module, BasicTransformerBlock)
339
+ ]
340
+ elif self.fusion_blocks == "full":
341
+ reader_attn_modules = [
342
+ module
343
+ for module in torch_dfs(self.unet)
344
+ if isinstance(module, BasicTransformerBlock)
345
+ ]
346
+ writer_attn_modules = [
347
+ module
348
+ for module in torch_dfs(writer.unet)
349
+ if isinstance(module, BasicTransformerBlock)
350
+ ]
351
+ reader_attn_modules = sorted(
352
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
353
+ )
354
+ writer_attn_modules = sorted(
355
+ writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
356
+ )
357
+ # import ipdb;ipdb.set_trace()
358
+ for r, w in zip(reader_attn_modules, writer_attn_modules):
359
+ r.bank = [v.clone().to(dtype) for v in w.bank]
360
+ if point_embedding_main is not None:
361
+ r.point_bank_ref=filter_matrices_by_size(point_embedding_ref, r.bank[0])
362
+ r.point_bank_main=filter_matrices_by_size(point_embedding_main, r.bank[0])
363
+ # w.bank.clear()
364
+
365
+ def clear(self):
366
+ if self.reference_attn:
367
+ if self.fusion_blocks == "midup":
368
+ reader_attn_modules = [
369
+ module
370
+ for module in (
371
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
372
+ )
373
+ if isinstance(module, BasicTransformerBlock)
374
+ or isinstance(module, TemporalBasicTransformerBlock)
375
+ ]
376
+ elif self.fusion_blocks == "full":
377
+ reader_attn_modules = [
378
+ module
379
+ for module in torch_dfs(self.unet)
380
+ if isinstance(module, BasicTransformerBlock)
381
+ or isinstance(module, TemporalBasicTransformerBlock)
382
+ ]
383
+ reader_attn_modules = sorted(
384
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
385
+ )
386
+ for r in reader_attn_modules:
387
+ r.bank.clear()
388
+ r.point_bank_ref.clear()
389
+ r.point_bank_main.clear()
src/models/refunet_2d_condition.py ADDED
@@ -0,0 +1,1307 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.utils.checkpoint
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import UNet2DConditionLoadersMixin
10
+ from diffusers.models.activations import get_activation
11
+ from diffusers.models.attention_processor import (
12
+ ADDED_KV_ATTENTION_PROCESSORS,
13
+ CROSS_ATTENTION_PROCESSORS,
14
+ AttentionProcessor,
15
+ AttnAddedKVProcessor,
16
+ AttnProcessor,
17
+ )
18
+ from diffusers.models.embeddings import (
19
+ GaussianFourierProjection,
20
+ ImageHintTimeEmbedding,
21
+ ImageProjection,
22
+ ImageTimeEmbedding,
23
+ TextImageProjection,
24
+ TextImageTimeEmbedding,
25
+ TextTimeEmbedding,
26
+ TimestepEmbedding,
27
+ Timesteps,
28
+ )
29
+ from diffusers.models.modeling_utils import ModelMixin
30
+ from diffusers.utils import (
31
+ USE_PEFT_BACKEND,
32
+ BaseOutput,
33
+ deprecate,
34
+ logging,
35
+ scale_lora_layers,
36
+ unscale_lora_layers,
37
+ )
38
+
39
+ from .unet_2d_blocks import (
40
+ UNetMidBlock2D,
41
+ UNetMidBlock2DCrossAttn,
42
+ get_down_block,
43
+ get_up_block,
44
+ )
45
+
46
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
+
48
+
49
+ @dataclass
50
+ class UNet2DConditionOutput(BaseOutput):
51
+ """
52
+ The output of [`UNet2DConditionModel`].
53
+
54
+ Args:
55
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
56
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
57
+ """
58
+
59
+ sample: torch.FloatTensor = None
60
+ ref_features: Tuple[torch.FloatTensor] = None
61
+
62
+
63
+ class RefUNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
64
+ r"""
65
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
66
+ shaped output.
67
+
68
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
69
+ for all models (such as downloading or saving).
70
+
71
+ Parameters:
72
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
73
+ Height and width of input/output sample.
74
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
75
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
76
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
77
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
78
+ Whether to flip the sin to cos in the time embedding.
79
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
80
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
81
+ The tuple of downsample blocks to use.
82
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
83
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
84
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
85
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
86
+ The tuple of upsample blocks to use.
87
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
88
+ Whether to include self-attention in the basic transformer blocks, see
89
+ [`~models.attention.BasicTransformerBlock`].
90
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
91
+ The tuple of output channels for each block.
92
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
93
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
94
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
95
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
96
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
97
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
98
+ If `None`, normalization and activation layers are skipped in post-processing.
99
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
100
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
101
+ The dimension of the cross attention features.
102
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
103
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
104
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
105
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
106
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
107
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
108
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
109
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
110
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
111
+ encoder_hid_dim (`int`, *optional*, defaults to None):
112
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
113
+ dimension to `cross_attention_dim`.
114
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
115
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
116
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
117
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
118
+ num_attention_heads (`int`, *optional*):
119
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
120
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
121
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
122
+ class_embed_type (`str`, *optional*, defaults to `None`):
123
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
124
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
125
+ addition_embed_type (`str`, *optional*, defaults to `None`):
126
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
127
+ "text". "text" will use the `TextTimeEmbedding` layer.
128
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
129
+ Dimension for the timestep embeddings.
130
+ num_class_embeds (`int`, *optional*, defaults to `None`):
131
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
132
+ class conditioning with `class_embed_type` equal to `None`.
133
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
134
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
135
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
136
+ An optional override for the dimension of the projected time embedding.
137
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
138
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
139
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
140
+ timestep_post_act (`str`, *optional*, defaults to `None`):
141
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
142
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
143
+ The dimension of `cond_proj` layer in the timestep embedding.
144
+ conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
145
+ conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
146
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
147
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
148
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
149
+ embeddings with the class embeddings.
150
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
151
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
152
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
153
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
154
+ otherwise.
155
+ """
156
+
157
+ _supports_gradient_checkpointing = True
158
+
159
+ @register_to_config
160
+ def __init__(
161
+ self,
162
+ sample_size: Optional[int] = None,
163
+ in_channels: int = 4,
164
+ out_channels: int = 4,
165
+ center_input_sample: bool = False,
166
+ flip_sin_to_cos: bool = True,
167
+ freq_shift: int = 0,
168
+ down_block_types: Tuple[str] = (
169
+ "CrossAttnDownBlock2D",
170
+ "CrossAttnDownBlock2D",
171
+ "CrossAttnDownBlock2D",
172
+ "DownBlock2D",
173
+ ),
174
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
175
+ up_block_types: Tuple[str] = (
176
+ "UpBlock2D",
177
+ "CrossAttnUpBlock2D",
178
+ "CrossAttnUpBlock2D",
179
+ "CrossAttnUpBlock2D",
180
+ ),
181
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
182
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
183
+ layers_per_block: Union[int, Tuple[int]] = 2,
184
+ downsample_padding: int = 1,
185
+ mid_block_scale_factor: float = 1,
186
+ dropout: float = 0.0,
187
+ act_fn: str = "silu",
188
+ norm_num_groups: Optional[int] = 32,
189
+ norm_eps: float = 1e-5,
190
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
191
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
192
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
193
+ encoder_hid_dim: Optional[int] = None,
194
+ encoder_hid_dim_type: Optional[str] = None,
195
+ attention_head_dim: Union[int, Tuple[int]] = 8,
196
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
197
+ dual_cross_attention: bool = False,
198
+ use_linear_projection: bool = False,
199
+ class_embed_type: Optional[str] = None,
200
+ addition_embed_type: Optional[str] = None,
201
+ addition_time_embed_dim: Optional[int] = None,
202
+ num_class_embeds: Optional[int] = None,
203
+ upcast_attention: bool = False,
204
+ resnet_time_scale_shift: str = "default",
205
+ resnet_skip_time_act: bool = False,
206
+ resnet_out_scale_factor: int = 1.0,
207
+ time_embedding_type: str = "positional",
208
+ time_embedding_dim: Optional[int] = None,
209
+ time_embedding_act_fn: Optional[str] = None,
210
+ timestep_post_act: Optional[str] = None,
211
+ time_cond_proj_dim: Optional[int] = None,
212
+ conv_in_kernel: int = 3,
213
+ conv_out_kernel: int = 3,
214
+ projection_class_embeddings_input_dim: Optional[int] = None,
215
+ attention_type: str = "default",
216
+ class_embeddings_concat: bool = False,
217
+ mid_block_only_cross_attention: Optional[bool] = None,
218
+ cross_attention_norm: Optional[str] = None,
219
+ addition_embed_type_num_heads=64,
220
+ ):
221
+ super().__init__()
222
+
223
+ self.sample_size = sample_size
224
+
225
+ if num_attention_heads is not None:
226
+ raise ValueError(
227
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
228
+ )
229
+
230
+ # If `num_attention_heads` is not defined (which is the case for most models)
231
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
232
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
233
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
234
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
235
+ # which is why we correct for the naming here.
236
+ num_attention_heads = num_attention_heads or attention_head_dim
237
+
238
+ # Check inputs
239
+ if len(down_block_types) != len(up_block_types):
240
+ raise ValueError(
241
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
242
+ )
243
+
244
+ if len(block_out_channels) != len(down_block_types):
245
+ raise ValueError(
246
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
247
+ )
248
+
249
+ if not isinstance(only_cross_attention, bool) and len(
250
+ only_cross_attention
251
+ ) != len(down_block_types):
252
+ raise ValueError(
253
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
254
+ )
255
+
256
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
257
+ down_block_types
258
+ ):
259
+ raise ValueError(
260
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
261
+ )
262
+
263
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
264
+ down_block_types
265
+ ):
266
+ raise ValueError(
267
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
268
+ )
269
+
270
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
271
+ down_block_types
272
+ ):
273
+ raise ValueError(
274
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
275
+ )
276
+
277
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
278
+ down_block_types
279
+ ):
280
+ raise ValueError(
281
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
282
+ )
283
+ if (
284
+ isinstance(transformer_layers_per_block, list)
285
+ and reverse_transformer_layers_per_block is None
286
+ ):
287
+ for layer_number_per_block in transformer_layers_per_block:
288
+ if isinstance(layer_number_per_block, list):
289
+ raise ValueError(
290
+ "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet."
291
+ )
292
+
293
+ # input
294
+ conv_in_padding = (conv_in_kernel - 1) // 2
295
+ self.conv_in = nn.Conv2d(
296
+ in_channels,
297
+ block_out_channels[0],
298
+ kernel_size=conv_in_kernel,
299
+ padding=conv_in_padding,
300
+ )
301
+
302
+ # time
303
+ if time_embedding_type == "fourier":
304
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
305
+ if time_embed_dim % 2 != 0:
306
+ raise ValueError(
307
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
308
+ )
309
+ self.time_proj = GaussianFourierProjection(
310
+ time_embed_dim // 2,
311
+ set_W_to_weight=False,
312
+ log=False,
313
+ flip_sin_to_cos=flip_sin_to_cos,
314
+ )
315
+ timestep_input_dim = time_embed_dim
316
+ elif time_embedding_type == "positional":
317
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
318
+
319
+ self.time_proj = Timesteps(
320
+ block_out_channels[0], flip_sin_to_cos, freq_shift
321
+ )
322
+ timestep_input_dim = block_out_channels[0]
323
+ else:
324
+ raise ValueError(
325
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
326
+ )
327
+
328
+ self.time_embedding = TimestepEmbedding(
329
+ timestep_input_dim,
330
+ time_embed_dim,
331
+ act_fn=act_fn,
332
+ post_act_fn=timestep_post_act,
333
+ cond_proj_dim=time_cond_proj_dim,
334
+ )
335
+
336
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
337
+ encoder_hid_dim_type = "text_proj"
338
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
339
+ logger.info(
340
+ "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
341
+ )
342
+
343
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
344
+ raise ValueError(
345
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
346
+ )
347
+
348
+ if encoder_hid_dim_type == "text_proj":
349
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
350
+ elif encoder_hid_dim_type == "text_image_proj":
351
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
352
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
353
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
354
+ self.encoder_hid_proj = TextImageProjection(
355
+ text_embed_dim=encoder_hid_dim,
356
+ image_embed_dim=cross_attention_dim,
357
+ cross_attention_dim=cross_attention_dim,
358
+ )
359
+ elif encoder_hid_dim_type == "image_proj":
360
+ # Kandinsky 2.2
361
+ self.encoder_hid_proj = ImageProjection(
362
+ image_embed_dim=encoder_hid_dim,
363
+ cross_attention_dim=cross_attention_dim,
364
+ )
365
+ elif encoder_hid_dim_type is not None:
366
+ raise ValueError(
367
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
368
+ )
369
+ else:
370
+ self.encoder_hid_proj = None
371
+
372
+ # class embedding
373
+ if class_embed_type is None and num_class_embeds is not None:
374
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
375
+ elif class_embed_type == "timestep":
376
+ self.class_embedding = TimestepEmbedding(
377
+ timestep_input_dim, time_embed_dim, act_fn=act_fn
378
+ )
379
+ elif class_embed_type == "identity":
380
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
381
+ elif class_embed_type == "projection":
382
+ if projection_class_embeddings_input_dim is None:
383
+ raise ValueError(
384
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
385
+ )
386
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
387
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
388
+ # 2. it projects from an arbitrary input dimension.
389
+ #
390
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
391
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
392
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
393
+ self.class_embedding = TimestepEmbedding(
394
+ projection_class_embeddings_input_dim, time_embed_dim
395
+ )
396
+ elif class_embed_type == "simple_projection":
397
+ if projection_class_embeddings_input_dim is None:
398
+ raise ValueError(
399
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
400
+ )
401
+ self.class_embedding = nn.Linear(
402
+ projection_class_embeddings_input_dim, time_embed_dim
403
+ )
404
+ else:
405
+ self.class_embedding = None
406
+
407
+ if addition_embed_type == "text":
408
+ if encoder_hid_dim is not None:
409
+ text_time_embedding_from_dim = encoder_hid_dim
410
+ else:
411
+ text_time_embedding_from_dim = cross_attention_dim
412
+
413
+ self.add_embedding = TextTimeEmbedding(
414
+ text_time_embedding_from_dim,
415
+ time_embed_dim,
416
+ num_heads=addition_embed_type_num_heads,
417
+ )
418
+ elif addition_embed_type == "text_image":
419
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
420
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
421
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
422
+ self.add_embedding = TextImageTimeEmbedding(
423
+ text_embed_dim=cross_attention_dim,
424
+ image_embed_dim=cross_attention_dim,
425
+ time_embed_dim=time_embed_dim,
426
+ )
427
+ elif addition_embed_type == "text_time":
428
+ self.add_time_proj = Timesteps(
429
+ addition_time_embed_dim, flip_sin_to_cos, freq_shift
430
+ )
431
+ self.add_embedding = TimestepEmbedding(
432
+ projection_class_embeddings_input_dim, time_embed_dim
433
+ )
434
+ elif addition_embed_type == "image":
435
+ # Kandinsky 2.2
436
+ self.add_embedding = ImageTimeEmbedding(
437
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
438
+ )
439
+ elif addition_embed_type == "image_hint":
440
+ # Kandinsky 2.2 ControlNet
441
+ self.add_embedding = ImageHintTimeEmbedding(
442
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
443
+ )
444
+ elif addition_embed_type is not None:
445
+ raise ValueError(
446
+ f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
447
+ )
448
+
449
+ if time_embedding_act_fn is None:
450
+ self.time_embed_act = None
451
+ else:
452
+ self.time_embed_act = get_activation(time_embedding_act_fn)
453
+
454
+ self.down_blocks = nn.ModuleList([])
455
+ self.up_blocks = nn.ModuleList([])
456
+
457
+ if isinstance(only_cross_attention, bool):
458
+ if mid_block_only_cross_attention is None:
459
+ mid_block_only_cross_attention = only_cross_attention
460
+
461
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
462
+
463
+ if mid_block_only_cross_attention is None:
464
+ mid_block_only_cross_attention = False
465
+
466
+ if isinstance(num_attention_heads, int):
467
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
468
+
469
+ if isinstance(attention_head_dim, int):
470
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
471
+
472
+ if isinstance(cross_attention_dim, int):
473
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
474
+
475
+ if isinstance(layers_per_block, int):
476
+ layers_per_block = [layers_per_block] * len(down_block_types)
477
+
478
+ if isinstance(transformer_layers_per_block, int):
479
+ transformer_layers_per_block = [transformer_layers_per_block] * len(
480
+ down_block_types
481
+ )
482
+
483
+ if class_embeddings_concat:
484
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
485
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
486
+ # regular time embeddings
487
+ blocks_time_embed_dim = time_embed_dim * 2
488
+ else:
489
+ blocks_time_embed_dim = time_embed_dim
490
+
491
+ # down
492
+ output_channel = block_out_channels[0]
493
+ for i, down_block_type in enumerate(down_block_types):
494
+ input_channel = output_channel
495
+ output_channel = block_out_channels[i]
496
+ is_final_block = i == len(block_out_channels) - 1
497
+
498
+ down_block = get_down_block(
499
+ down_block_type,
500
+ num_layers=layers_per_block[i],
501
+ transformer_layers_per_block=transformer_layers_per_block[i],
502
+ in_channels=input_channel,
503
+ out_channels=output_channel,
504
+ temb_channels=blocks_time_embed_dim,
505
+ add_downsample=not is_final_block,
506
+ resnet_eps=norm_eps,
507
+ resnet_act_fn=act_fn,
508
+ resnet_groups=norm_num_groups,
509
+ cross_attention_dim=cross_attention_dim[i],
510
+ num_attention_heads=num_attention_heads[i],
511
+ downsample_padding=downsample_padding,
512
+ dual_cross_attention=dual_cross_attention,
513
+ use_linear_projection=use_linear_projection,
514
+ only_cross_attention=only_cross_attention[i],
515
+ upcast_attention=upcast_attention,
516
+ resnet_time_scale_shift=resnet_time_scale_shift,
517
+ attention_type=attention_type,
518
+ resnet_skip_time_act=resnet_skip_time_act,
519
+ resnet_out_scale_factor=resnet_out_scale_factor,
520
+ cross_attention_norm=cross_attention_norm,
521
+ attention_head_dim=attention_head_dim[i]
522
+ if attention_head_dim[i] is not None
523
+ else output_channel,
524
+ dropout=dropout,
525
+ )
526
+ self.down_blocks.append(down_block)
527
+
528
+ # mid
529
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
530
+ self.mid_block = UNetMidBlock2DCrossAttn(
531
+ transformer_layers_per_block=transformer_layers_per_block[-1],
532
+ in_channels=block_out_channels[-1],
533
+ temb_channels=blocks_time_embed_dim,
534
+ dropout=dropout,
535
+ resnet_eps=norm_eps,
536
+ resnet_act_fn=act_fn,
537
+ output_scale_factor=mid_block_scale_factor,
538
+ resnet_time_scale_shift=resnet_time_scale_shift,
539
+ cross_attention_dim=cross_attention_dim[-1],
540
+ num_attention_heads=num_attention_heads[-1],
541
+ resnet_groups=norm_num_groups,
542
+ dual_cross_attention=dual_cross_attention,
543
+ use_linear_projection=use_linear_projection,
544
+ upcast_attention=upcast_attention,
545
+ attention_type=attention_type,
546
+ )
547
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
548
+ raise NotImplementedError(f"Unsupport mid_block_type: {mid_block_type}")
549
+ elif mid_block_type == "UNetMidBlock2D":
550
+ self.mid_block = UNetMidBlock2D(
551
+ in_channels=block_out_channels[-1],
552
+ temb_channels=blocks_time_embed_dim,
553
+ dropout=dropout,
554
+ num_layers=0,
555
+ resnet_eps=norm_eps,
556
+ resnet_act_fn=act_fn,
557
+ output_scale_factor=mid_block_scale_factor,
558
+ resnet_groups=norm_num_groups,
559
+ resnet_time_scale_shift=resnet_time_scale_shift,
560
+ add_attention=False,
561
+ )
562
+ elif mid_block_type is None:
563
+ self.mid_block = None
564
+ else:
565
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
566
+
567
+ # count how many layers upsample the images
568
+ self.num_upsamplers = 0
569
+
570
+ # up
571
+ reversed_block_out_channels = list(reversed(block_out_channels))
572
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
573
+ reversed_layers_per_block = list(reversed(layers_per_block))
574
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
575
+ reversed_transformer_layers_per_block = (
576
+ list(reversed(transformer_layers_per_block))
577
+ if reverse_transformer_layers_per_block is None
578
+ else reverse_transformer_layers_per_block
579
+ )
580
+ only_cross_attention = list(reversed(only_cross_attention))
581
+
582
+ output_channel = reversed_block_out_channels[0]
583
+ for i, up_block_type in enumerate(up_block_types):
584
+ is_final_block = i == len(block_out_channels) - 1
585
+
586
+ prev_output_channel = output_channel
587
+ output_channel = reversed_block_out_channels[i]
588
+ input_channel = reversed_block_out_channels[
589
+ min(i + 1, len(block_out_channels) - 1)
590
+ ]
591
+
592
+ # add upsample block for all BUT final layer
593
+ if not is_final_block:
594
+ add_upsample = True
595
+ self.num_upsamplers += 1
596
+ else:
597
+ add_upsample = False
598
+
599
+ up_block = get_up_block(
600
+ up_block_type,
601
+ num_layers=reversed_layers_per_block[i] + 1,
602
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
603
+ in_channels=input_channel,
604
+ out_channels=output_channel,
605
+ prev_output_channel=prev_output_channel,
606
+ temb_channels=blocks_time_embed_dim,
607
+ add_upsample=add_upsample,
608
+ resnet_eps=norm_eps,
609
+ resnet_act_fn=act_fn,
610
+ resolution_idx=i,
611
+ resnet_groups=norm_num_groups,
612
+ cross_attention_dim=reversed_cross_attention_dim[i],
613
+ num_attention_heads=reversed_num_attention_heads[i],
614
+ dual_cross_attention=dual_cross_attention,
615
+ use_linear_projection=use_linear_projection,
616
+ only_cross_attention=only_cross_attention[i],
617
+ upcast_attention=upcast_attention,
618
+ resnet_time_scale_shift=resnet_time_scale_shift,
619
+ attention_type=attention_type,
620
+ resnet_skip_time_act=resnet_skip_time_act,
621
+ resnet_out_scale_factor=resnet_out_scale_factor,
622
+ cross_attention_norm=cross_attention_norm,
623
+ attention_head_dim=attention_head_dim[i]
624
+ if attention_head_dim[i] is not None
625
+ else output_channel,
626
+ dropout=dropout,
627
+ )
628
+ self.up_blocks.append(up_block)
629
+ prev_output_channel = output_channel
630
+
631
+ # out
632
+ if norm_num_groups is not None:
633
+ self.conv_norm_out = nn.GroupNorm(
634
+ num_channels=block_out_channels[0],
635
+ num_groups=norm_num_groups,
636
+ eps=norm_eps,
637
+ )
638
+
639
+ self.conv_act = get_activation(act_fn)
640
+
641
+ else:
642
+ self.conv_norm_out = None
643
+ self.conv_act = None
644
+ self.conv_norm_out = None
645
+
646
+ conv_out_padding = (conv_out_kernel - 1) // 2
647
+ # self.conv_out = nn.Conv2d(
648
+ # block_out_channels[0],
649
+ # out_channels,
650
+ # kernel_size=conv_out_kernel,
651
+ # padding=conv_out_padding,
652
+ # )
653
+
654
+ if attention_type in ["gated", "gated-text-image"]:
655
+ positive_len = 768
656
+ if isinstance(cross_attention_dim, int):
657
+ positive_len = cross_attention_dim
658
+ elif isinstance(cross_attention_dim, tuple) or isinstance(
659
+ cross_attention_dim, list
660
+ ):
661
+ positive_len = cross_attention_dim[0]
662
+
663
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
664
+ # self.position_net = PositionNet(
665
+ # positive_len=positive_len,
666
+ # out_dim=cross_attention_dim,
667
+ # feature_type=feature_type,
668
+ # )
669
+
670
+ @property
671
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
672
+ r"""
673
+ Returns:
674
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
676
+ indexed by their weight names.
676
+ """
677
+ # set recursively
678
+ processors = {}
679
+
680
+ def fn_recursive_add_processors(
681
+ name: str,
682
+ module: torch.nn.Module,
683
+ processors: Dict[str, AttentionProcessor],
684
+ ):
685
+ if hasattr(module, "get_processor"):
686
+ processors[f"{name}.processor"] = module.get_processor(
687
+ return_deprecated_lora=True
688
+ )
689
+
690
+ for sub_name, child in module.named_children():
691
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
692
+
693
+ return processors
694
+
695
+ for name, module in self.named_children():
696
+ fn_recursive_add_processors(name, module, processors)
697
+
698
+ return processors
699
+
700
+ def set_attn_processor(
701
+ self,
702
+ processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
703
+ _remove_lora=False,
704
+ ):
705
+ r"""
706
+ Sets the attention processor to use to compute attention.
707
+
708
+ Parameters:
709
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
710
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
711
+ for **all** `Attention` layers.
712
+
713
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
714
+ processor. This is strongly recommended when setting trainable attention processors.
715
+
716
+ """
717
+ count = len(self.attn_processors.keys())
718
+
719
+ if isinstance(processor, dict) and len(processor) != count:
720
+ raise ValueError(
721
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
722
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
723
+ )
724
+
725
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
726
+ if hasattr(module, "set_processor"):
727
+ if not isinstance(processor, dict):
728
+ module.set_processor(processor, _remove_lora=_remove_lora)
729
+ else:
730
+ module.set_processor(
731
+ processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
732
+ )
733
+
734
+ for sub_name, child in module.named_children():
735
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
736
+
737
+ for name, module in self.named_children():
738
+ fn_recursive_attn_processor(name, module, processor)
739
+
740
+ def set_default_attn_processor(self):
741
+ """
742
+ Disables custom attention processors and sets the default attention implementation.
743
+ """
744
+ if all(
745
+ proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
746
+ for proc in self.attn_processors.values()
747
+ ):
748
+ processor = AttnAddedKVProcessor()
749
+ elif all(
750
+ proc.__class__ in CROSS_ATTENTION_PROCESSORS
751
+ for proc in self.attn_processors.values()
752
+ ):
753
+ processor = AttnProcessor()
754
+ else:
755
+ raise ValueError(
756
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
757
+ )
758
+
759
+ self.set_attn_processor(processor, _remove_lora=True)
760
+
761
+ def set_attention_slice(self, slice_size):
762
+ r"""
763
+ Enable sliced attention computation.
764
+
765
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
766
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
767
+
768
+ Args:
769
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
770
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
771
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
772
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
773
+ must be a multiple of `slice_size`.
774
+ """
775
+ sliceable_head_dims = []
776
+
777
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
778
+ if hasattr(module, "set_attention_slice"):
779
+ sliceable_head_dims.append(module.sliceable_head_dim)
780
+
781
+ for child in module.children():
782
+ fn_recursive_retrieve_sliceable_dims(child)
783
+
784
+ # retrieve number of attention layers
785
+ for module in self.children():
786
+ fn_recursive_retrieve_sliceable_dims(module)
787
+
788
+ num_sliceable_layers = len(sliceable_head_dims)
789
+
790
+ if slice_size == "auto":
791
+ # half the attention head size is usually a good trade-off between
792
+ # speed and memory
793
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
794
+ elif slice_size == "max":
795
+ # make smallest slice possible
796
+ slice_size = num_sliceable_layers * [1]
797
+
798
+ slice_size = (
799
+ num_sliceable_layers * [slice_size]
800
+ if not isinstance(slice_size, list)
801
+ else slice_size
802
+ )
803
+
804
+ if len(slice_size) != len(sliceable_head_dims):
805
+ raise ValueError(
806
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
807
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
808
+ )
809
+
810
+ for i in range(len(slice_size)):
811
+ size = slice_size[i]
812
+ dim = sliceable_head_dims[i]
813
+ if size is not None and size > dim:
814
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
815
+
816
+ # Recursively walk through all the children.
817
+ # Any children which exposes the set_attention_slice method
818
+ # gets the message
819
+ def fn_recursive_set_attention_slice(
820
+ module: torch.nn.Module, slice_size: List[int]
821
+ ):
822
+ if hasattr(module, "set_attention_slice"):
823
+ module.set_attention_slice(slice_size.pop())
824
+
825
+ for child in module.children():
826
+ fn_recursive_set_attention_slice(child, slice_size)
827
+
828
+ reversed_slice_size = list(reversed(slice_size))
829
+ for module in self.children():
830
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
831
+
832
+ def _set_gradient_checkpointing(self, module, value=False):
833
+ if hasattr(module, "gradient_checkpointing"):
834
+ module.gradient_checkpointing = value
835
+
836
+ def enable_freeu(self, s1, s2, b1, b2):
837
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
838
+
839
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
840
+
841
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
842
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
843
+
844
+ Args:
845
+ s1 (`float`):
846
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
847
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
848
+ s2 (`float`):
849
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
850
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
851
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
852
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
853
+ """
854
+ for i, upsample_block in enumerate(self.up_blocks):
855
+ setattr(upsample_block, "s1", s1)
856
+ setattr(upsample_block, "s2", s2)
857
+ setattr(upsample_block, "b1", b1)
858
+ setattr(upsample_block, "b2", b2)
859
+
860
+ def disable_freeu(self):
861
+ """Disables the FreeU mechanism."""
862
+ freeu_keys = {"s1", "s2", "b1", "b2"}
863
+ for i, upsample_block in enumerate(self.up_blocks):
864
+ for k in freeu_keys:
865
+ if (
866
+ hasattr(upsample_block, k)
867
+ or getattr(upsample_block, k, None) is not None
868
+ ):
869
+ setattr(upsample_block, k, None)
870
+
871
+ def forward(
872
+ self,
873
+ sample: torch.FloatTensor,
874
+ timestep: Union[torch.Tensor, float, int],
875
+ encoder_hidden_states: torch.Tensor,
876
+ class_labels: Optional[torch.Tensor] = None,
877
+ timestep_cond: Optional[torch.Tensor] = None,
878
+ attention_mask: Optional[torch.Tensor] = None,
879
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
880
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
881
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
882
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
883
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
884
+ encoder_attention_mask: Optional[torch.Tensor] = None,
885
+ return_dict: bool = True,
886
+ ) -> Union[UNet2DConditionOutput, Tuple]:
887
+ r"""
888
+ The [`UNet2DConditionModel`] forward method.
889
+
890
+ Args:
891
+ sample (`torch.FloatTensor`):
892
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
893
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
894
+ encoder_hidden_states (`torch.FloatTensor`):
895
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
896
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
897
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
898
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
899
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
900
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
901
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
902
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
903
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
904
+ negative values to the attention scores corresponding to "discard" tokens.
905
+ cross_attention_kwargs (`dict`, *optional*):
906
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
907
+ `self.processor` in
908
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
909
+ added_cond_kwargs: (`dict`, *optional*):
910
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
911
+ are passed along to the UNet blocks.
912
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
913
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
914
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
915
+ A tensor that if specified is added to the residual of the middle unet block.
916
+ encoder_attention_mask (`torch.Tensor`):
917
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
918
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
919
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
920
+ return_dict (`bool`, *optional*, defaults to `True`):
921
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
922
+ tuple.
923
+ cross_attention_kwargs (`dict`, *optional*):
924
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
925
+ added_cond_kwargs: (`dict`, *optional*):
926
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
927
+ are passed along to the UNet blocks.
928
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
929
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
930
+ example from ControlNet side model(s)
931
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
932
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
933
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
934
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
935
+
936
+ Returns:
937
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
938
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
939
+ a `tuple` is returned where the first element is the sample tensor.
940
+ """
941
+ # By default samples have to be at least a multiple of the overall upsampling factor.
942
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
943
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
944
+ # on the fly if necessary.
945
+ default_overall_up_factor = 2**self.num_upsamplers
946
+
947
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
948
+ forward_upsample_size = False
949
+ upsample_size = None
950
+
951
+ for dim in sample.shape[-2:]:
952
+ if dim % default_overall_up_factor != 0:
953
+ # Forward upsample size to force interpolation output size.
954
+ forward_upsample_size = True
955
+ break
956
+
957
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
958
+ # expects mask of shape:
959
+ # [batch, key_tokens]
960
+ # adds singleton query_tokens dimension:
961
+ # [batch, 1, key_tokens]
962
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
963
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
964
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
965
+ if attention_mask is not None:
966
+ # assume that mask is expressed as:
967
+ # (1 = keep, 0 = discard)
968
+ # convert mask into a bias that can be added to attention scores:
969
+ # (keep = +0, discard = -10000.0)
970
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
971
+ attention_mask = attention_mask.unsqueeze(1)
972
+
973
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
974
+ if encoder_attention_mask is not None:
975
+ encoder_attention_mask = (
976
+ 1 - encoder_attention_mask.to(sample.dtype)
977
+ ) * -10000.0
978
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
979
+
980
+ # 0. center input if necessary
981
+ if self.config.center_input_sample:
982
+ sample = 2 * sample - 1.0
983
+
984
+ # 1. time
985
+ timesteps = timestep
986
+ if not torch.is_tensor(timesteps):
987
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
988
+ # This would be a good case for the `match` statement (Python 3.10+)
989
+ is_mps = sample.device.type == "mps"
990
+ if isinstance(timestep, float):
991
+ dtype = torch.float32 if is_mps else torch.float64
992
+ else:
993
+ dtype = torch.int32 if is_mps else torch.int64
994
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
995
+ elif len(timesteps.shape) == 0:
996
+ timesteps = timesteps[None].to(sample.device)
997
+
998
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
999
+ timesteps = timesteps.expand(sample.shape[0])
1000
+
1001
+ t_emb = self.time_proj(timesteps)
1002
+
1003
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1004
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1005
+ # there might be better ways to encapsulate this.
1006
+ t_emb = t_emb.to(dtype=sample.dtype)
1007
+
1008
+ emb = self.time_embedding(t_emb, timestep_cond)
1009
+ aug_emb = None
1010
+
1011
+ if self.class_embedding is not None:
1012
+ if class_labels is None:
1013
+ raise ValueError(
1014
+ "class_labels should be provided when num_class_embeds > 0"
1015
+ )
1016
+
1017
+ if self.config.class_embed_type == "timestep":
1018
+ class_labels = self.time_proj(class_labels)
1019
+
1020
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1021
+ # there might be better ways to encapsulate this.
1022
+ class_labels = class_labels.to(dtype=sample.dtype)
1023
+
1024
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1025
+
1026
+ if self.config.class_embeddings_concat:
1027
+ emb = torch.cat([emb, class_emb], dim=-1)
1028
+ else:
1029
+ emb = emb + class_emb
1030
+
1031
+ if self.config.addition_embed_type == "text":
1032
+ aug_emb = self.add_embedding(encoder_hidden_states)
1033
+ elif self.config.addition_embed_type == "text_image":
1034
+ # Kandinsky 2.1 - style
1035
+ if "image_embeds" not in added_cond_kwargs:
1036
+ raise ValueError(
1037
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1038
+ )
1039
+
1040
+ image_embs = added_cond_kwargs.get("image_embeds")
1041
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1042
+ aug_emb = self.add_embedding(text_embs, image_embs)
1043
+ elif self.config.addition_embed_type == "text_time":
1044
+ # SDXL - style
1045
+ if "text_embeds" not in added_cond_kwargs:
1046
+ raise ValueError(
1047
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1048
+ )
1049
+ text_embeds = added_cond_kwargs.get("text_embeds")
1050
+ if "time_ids" not in added_cond_kwargs:
1051
+ raise ValueError(
1052
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1053
+ )
1054
+ time_ids = added_cond_kwargs.get("time_ids")
1055
+ time_embeds = self.add_time_proj(time_ids.flatten())
1056
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1057
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1058
+ add_embeds = add_embeds.to(emb.dtype)
1059
+ aug_emb = self.add_embedding(add_embeds)
1060
+ elif self.config.addition_embed_type == "image":
1061
+ # Kandinsky 2.2 - style
1062
+ if "image_embeds" not in added_cond_kwargs:
1063
+ raise ValueError(
1064
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1065
+ )
1066
+ image_embs = added_cond_kwargs.get("image_embeds")
1067
+ aug_emb = self.add_embedding(image_embs)
1068
+ elif self.config.addition_embed_type == "image_hint":
1069
+ # Kandinsky 2.2 - style
1070
+ if (
1071
+ "image_embeds" not in added_cond_kwargs
1072
+ or "hint" not in added_cond_kwargs
1073
+ ):
1074
+ raise ValueError(
1075
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1076
+ )
1077
+ image_embs = added_cond_kwargs.get("image_embeds")
1078
+ hint = added_cond_kwargs.get("hint")
1079
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1080
+ sample = torch.cat([sample, hint], dim=1)
1081
+
1082
+ emb = emb + aug_emb if aug_emb is not None else emb
1083
+
1084
+ if self.time_embed_act is not None:
1085
+ emb = self.time_embed_act(emb)
1086
+
1087
+ if (
1088
+ self.encoder_hid_proj is not None
1089
+ and self.config.encoder_hid_dim_type == "text_proj"
1090
+ ):
1091
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1092
+ elif (
1093
+ self.encoder_hid_proj is not None
1094
+ and self.config.encoder_hid_dim_type == "text_image_proj"
1095
+ ):
1096
+ # Kandinsky 2.1 - style
1097
+ if "image_embeds" not in added_cond_kwargs:
1098
+ raise ValueError(
1099
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1100
+ )
1101
+
1102
+ image_embeds = added_cond_kwargs.get("image_embeds")
1103
+ encoder_hidden_states = self.encoder_hid_proj(
1104
+ encoder_hidden_states, image_embeds
1105
+ )
1106
+ elif (
1107
+ self.encoder_hid_proj is not None
1108
+ and self.config.encoder_hid_dim_type == "image_proj"
1109
+ ):
1110
+ # Kandinsky 2.2 - style
1111
+ if "image_embeds" not in added_cond_kwargs:
1112
+ raise ValueError(
1113
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1114
+ )
1115
+ image_embeds = added_cond_kwargs.get("image_embeds")
1116
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1117
+ elif (
1118
+ self.encoder_hid_proj is not None
1119
+ and self.config.encoder_hid_dim_type == "ip_image_proj"
1120
+ ):
1121
+ if "image_embeds" not in added_cond_kwargs:
1122
+ raise ValueError(
1123
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1124
+ )
1125
+ image_embeds = added_cond_kwargs.get("image_embeds")
1126
+ image_embeds = self.encoder_hid_proj(image_embeds).to(
1127
+ encoder_hidden_states.dtype
1128
+ )
1129
+ encoder_hidden_states = torch.cat(
1130
+ [encoder_hidden_states, image_embeds], dim=1
1131
+ )
1132
+
1133
+ # 2. pre-process
1134
+ sample = self.conv_in(sample)
1135
+
1136
+ # # 2.5 GLIGEN position net
1137
+ # if (
1138
+ # cross_attention_kwargs is not None
1139
+ # and cross_attention_kwargs.get("gligen", None) is not None
1140
+ # ):
1141
+ # cross_attention_kwargs = cross_attention_kwargs.copy()
1142
+ # gligen_args = cross_attention_kwargs.pop("gligen")
1143
+ # cross_attention_kwargs["gligen"] = {
1144
+ # "objs": self.position_net(**gligen_args)
1145
+ # }
1146
+
1147
+ # 3. down
1148
+ lora_scale = (
1149
+ cross_attention_kwargs.get("scale", 1.0)
1150
+ if cross_attention_kwargs is not None
1151
+ else 1.0
1152
+ )
1153
+ if USE_PEFT_BACKEND:
1154
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1155
+ scale_lora_layers(self, lora_scale)
1156
+
1157
+ is_controlnet = (
1158
+ mid_block_additional_residual is not None
1159
+ and down_block_additional_residuals is not None
1160
+ )
1161
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1162
+ is_adapter = down_intrablock_additional_residuals is not None
1163
+ # maintain backward compatibility for legacy usage, where
1164
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1165
+ # but can only use one or the other
1166
+ if (
1167
+ not is_adapter
1168
+ and mid_block_additional_residual is None
1169
+ and down_block_additional_residuals is not None
1170
+ ):
1171
+ deprecate(
1172
+ "T2I should not use down_block_additional_residuals",
1173
+ "1.3.0",
1174
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1175
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1176
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1177
+ standard_warn=False,
1178
+ )
1179
+ down_intrablock_additional_residuals = down_block_additional_residuals
1180
+ is_adapter = True
1181
+
1182
+ down_block_res_samples = (sample,)
1183
+ tot_referece_features = ()
1184
+ for downsample_block in self.down_blocks:
1185
+ if (
1186
+ hasattr(downsample_block, "has_cross_attention")
1187
+ and downsample_block.has_cross_attention
1188
+ ):
1189
+ # For t2i-adapter CrossAttnDownBlock2D
1190
+ additional_residuals = {}
1191
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1192
+ additional_residuals[
1193
+ "additional_residuals"
1194
+ ] = down_intrablock_additional_residuals.pop(0)
1195
+
1196
+ sample, res_samples = downsample_block(
1197
+ hidden_states=sample,
1198
+ temb=emb,
1199
+ encoder_hidden_states=encoder_hidden_states,
1200
+ attention_mask=attention_mask,
1201
+ cross_attention_kwargs=cross_attention_kwargs,
1202
+ encoder_attention_mask=encoder_attention_mask,
1203
+ **additional_residuals,
1204
+ )
1205
+ else:
1206
+ sample, res_samples = downsample_block(
1207
+ hidden_states=sample, temb=emb, scale=lora_scale
1208
+ )
1209
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1210
+ sample += down_intrablock_additional_residuals.pop(0)
1211
+
1212
+ down_block_res_samples += res_samples
1213
+
1214
+ if is_controlnet:
1215
+ new_down_block_res_samples = ()
1216
+
1217
+ for down_block_res_sample, down_block_additional_residual in zip(
1218
+ down_block_res_samples, down_block_additional_residuals
1219
+ ):
1220
+ down_block_res_sample = (
1221
+ down_block_res_sample + down_block_additional_residual
1222
+ )
1223
+ new_down_block_res_samples = new_down_block_res_samples + (
1224
+ down_block_res_sample,
1225
+ )
1226
+
1227
+ down_block_res_samples = new_down_block_res_samples
1228
+
1229
+ # 4. mid
1230
+ if self.mid_block is not None:
1231
+ if (
1232
+ hasattr(self.mid_block, "has_cross_attention")
1233
+ and self.mid_block.has_cross_attention
1234
+ ):
1235
+ sample = self.mid_block(
1236
+ sample,
1237
+ emb,
1238
+ encoder_hidden_states=encoder_hidden_states,
1239
+ attention_mask=attention_mask,
1240
+ cross_attention_kwargs=cross_attention_kwargs,
1241
+ encoder_attention_mask=encoder_attention_mask,
1242
+ )
1243
+ else:
1244
+ sample = self.mid_block(sample, emb)
1245
+
1246
+ # To support T2I-Adapter-XL
1247
+ if (
1248
+ is_adapter
1249
+ and len(down_intrablock_additional_residuals) > 0
1250
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1251
+ ):
1252
+ sample += down_intrablock_additional_residuals.pop(0)
1253
+
1254
+ if is_controlnet:
1255
+ sample = sample + mid_block_additional_residual
1256
+
1257
+ # 5. up
1258
+ for i, upsample_block in enumerate(self.up_blocks):
1259
+ is_final_block = i == len(self.up_blocks) - 1
1260
+
1261
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1262
+ down_block_res_samples = down_block_res_samples[
1263
+ : -len(upsample_block.resnets)
1264
+ ]
1265
+
1266
+ # if we have not reached the final block and need to forward the
1267
+ # upsample size, we do it here
1268
+ if not is_final_block and forward_upsample_size:
1269
+ upsample_size = down_block_res_samples[-1].shape[2:]
1270
+
1271
+ if (
1272
+ hasattr(upsample_block, "has_cross_attention")
1273
+ and upsample_block.has_cross_attention
1274
+ ):
1275
+ sample = upsample_block(
1276
+ hidden_states=sample,
1277
+ temb=emb,
1278
+ res_hidden_states_tuple=res_samples,
1279
+ encoder_hidden_states=encoder_hidden_states,
1280
+ cross_attention_kwargs=cross_attention_kwargs,
1281
+ upsample_size=upsample_size,
1282
+ attention_mask=attention_mask,
1283
+ encoder_attention_mask=encoder_attention_mask,
1284
+ )
1285
+ else:
1286
+ sample = upsample_block(
1287
+ hidden_states=sample,
1288
+ temb=emb,
1289
+ res_hidden_states_tuple=res_samples,
1290
+ upsample_size=upsample_size,
1291
+ scale=lora_scale,
1292
+ )
1293
+
1294
+ # 6. post-process
1295
+ # if self.conv_norm_out:
1296
+ # sample = self.conv_norm_out(sample)
1297
+ # sample = self.conv_act(sample)
1298
+ # sample = self.conv_out(sample)
1299
+
1300
+ if USE_PEFT_BACKEND:
1301
+ # remove `lora_scale` from each PEFT layer
1302
+ unscale_lora_layers(self, lora_scale)
1303
+
1304
+ if not return_dict:
1305
+ return (sample,)
1306
+
1307
+ return UNet2DConditionOutput(sample=sample)
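Editor's note: the ControlNet / T2I-Adapter handling in the forward pass above reduces to element-wise addition of externally supplied residual tensors. The sketch below is an illustration only (the shapes and values are invented, not taken from this commit) of what the `is_controlnet` merge loop in step 3 and the mid-block addition in step 4 compute.

import torch

# Invented shapes, for illustration: three skip tensors collected on the way down.
down_block_res_samples = tuple(torch.randn(1, 320, 32, 32) for _ in range(3))
controlnet_residuals = tuple(torch.randn(1, 320, 32, 32) for _ in range(3))

# What the `is_controlnet` branch does: add the matching residual to each skip tensor.
merged = tuple(s + r for s, r in zip(down_block_res_samples, controlnet_residuals))

# The mid-block output is handled the same way with `mid_block_additional_residual`.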
src/models/transformer_2d.py ADDED
@@ -0,0 +1,396 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ # from diffusers.models.embeddings import CaptionProjection
8
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
9
+ from diffusers.models.modeling_utils import ModelMixin
10
+ from diffusers.models.normalization import AdaLayerNormSingle
11
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
12
+ from torch import nn
13
+
14
+ from .attention import BasicTransformerBlock
15
+
16
+
17
+ @dataclass
18
+ class Transformer2DModelOutput(BaseOutput):
19
+ """
20
+ The output of [`Transformer2DModel`].
21
+
22
+ Args:
23
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
24
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
25
+ distributions for the unnoised latent pixels.
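+ ref_feature (`torch.FloatTensor` of shape `(batch_size, height, width, inner_dim)`):
+ The hidden states captured right before the transformer blocks are applied (see `forward`), returned
+ alongside `sample` for downstream reference-feature use.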
26
+ """
27
+
28
+ sample: torch.FloatTensor
29
+ ref_feature: torch.FloatTensor
30
+
31
+
32
+ class Transformer2DModel(ModelMixin, ConfigMixin):
33
+ """
34
+ A 2D Transformer model for image-like data.
35
+
36
+ Parameters:
37
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
38
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
39
+ in_channels (`int`, *optional*):
40
+ The number of channels in the input and output (specify if the input is **continuous**).
41
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
42
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
43
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
44
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
45
+ This is fixed during training since it is used to learn a number of position embeddings.
46
+ num_vector_embeds (`int`, *optional*):
47
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
48
+ Includes the class for the masked latent pixel.
49
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
50
+ num_embeds_ada_norm ( `int`, *optional*):
51
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
52
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
53
+ added to the hidden states.
54
+
55
+ During inference, you can denoise for up to but not more than `num_embeds_ada_norm` steps.
56
+ attention_bias (`bool`, *optional*):
57
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
58
+ """
59
+
60
+ _supports_gradient_checkpointing = True
61
+
62
+ @register_to_config
63
+ def __init__(
64
+ self,
65
+ num_attention_heads: int = 16,
66
+ attention_head_dim: int = 88,
67
+ in_channels: Optional[int] = None,
68
+ out_channels: Optional[int] = None,
69
+ num_layers: int = 1,
70
+ dropout: float = 0.0,
71
+ norm_num_groups: int = 32,
72
+ cross_attention_dim: Optional[int] = None,
73
+ attention_bias: bool = False,
74
+ sample_size: Optional[int] = None,
75
+ num_vector_embeds: Optional[int] = None,
76
+ patch_size: Optional[int] = None,
77
+ activation_fn: str = "geglu",
78
+ num_embeds_ada_norm: Optional[int] = None,
79
+ use_linear_projection: bool = False,
80
+ only_cross_attention: bool = False,
81
+ double_self_attention: bool = False,
82
+ upcast_attention: bool = False,
83
+ norm_type: str = "layer_norm",
84
+ norm_elementwise_affine: bool = True,
85
+ norm_eps: float = 1e-5,
86
+ attention_type: str = "default",
87
+ caption_channels: int = None,
88
+ ):
89
+ super().__init__()
90
+ self.use_linear_projection = use_linear_projection
91
+ self.num_attention_heads = num_attention_heads
92
+ self.attention_head_dim = attention_head_dim
93
+ inner_dim = num_attention_heads * attention_head_dim
94
+
95
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
96
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
97
+
98
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
99
+ # Define whether input is continuous or discrete depending on configuration
100
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
101
+ self.is_input_vectorized = num_vector_embeds is not None
102
+ self.is_input_patches = in_channels is not None and patch_size is not None
103
+
104
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
105
+ deprecation_message = (
106
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
107
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
108
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
109
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
110
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
111
+ )
112
+ deprecate(
113
+ "norm_type!=num_embeds_ada_norm",
114
+ "1.0.0",
115
+ deprecation_message,
116
+ standard_warn=False,
117
+ )
118
+ norm_type = "ada_norm"
119
+
120
+ if self.is_input_continuous and self.is_input_vectorized:
121
+ raise ValueError(
122
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
123
+ " sure that either `in_channels` or `num_vector_embeds` is None."
124
+ )
125
+ elif self.is_input_vectorized and self.is_input_patches:
126
+ raise ValueError(
127
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
128
+ " sure that either `num_vector_embeds` or `num_patches` is None."
129
+ )
130
+ elif (
131
+ not self.is_input_continuous
132
+ and not self.is_input_vectorized
133
+ and not self.is_input_patches
134
+ ):
135
+ raise ValueError(
136
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
137
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
138
+ )
139
+
140
+ # 2. Define input layers
141
+ self.in_channels = in_channels
142
+
143
+ self.norm = torch.nn.GroupNorm(
144
+ num_groups=norm_num_groups,
145
+ num_channels=in_channels,
146
+ eps=1e-6,
147
+ affine=True,
148
+ )
149
+ if use_linear_projection:
150
+ self.proj_in = linear_cls(in_channels, inner_dim)
151
+ else:
152
+ self.proj_in = conv_cls(
153
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
154
+ )
155
+
156
+ # 3. Define transformers blocks
157
+ self.transformer_blocks = nn.ModuleList(
158
+ [
159
+ BasicTransformerBlock(
160
+ inner_dim,
161
+ num_attention_heads,
162
+ attention_head_dim,
163
+ dropout=dropout,
164
+ cross_attention_dim=cross_attention_dim,
165
+ activation_fn=activation_fn,
166
+ num_embeds_ada_norm=num_embeds_ada_norm,
167
+ attention_bias=attention_bias,
168
+ only_cross_attention=only_cross_attention,
169
+ double_self_attention=double_self_attention,
170
+ upcast_attention=upcast_attention,
171
+ norm_type=norm_type,
172
+ norm_elementwise_affine=norm_elementwise_affine,
173
+ norm_eps=norm_eps,
174
+ attention_type=attention_type,
175
+ )
176
+ for d in range(num_layers)
177
+ ]
178
+ )
179
+
180
+ # 4. Define output layers
181
+ self.out_channels = in_channels if out_channels is None else out_channels
182
+ # TODO: should use out_channels for continuous projections
183
+ if use_linear_projection:
184
+ self.proj_out = linear_cls(inner_dim, in_channels)
185
+ else:
186
+ self.proj_out = conv_cls(
187
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
188
+ )
189
+
190
+ # 5. PixArt-Alpha blocks.
191
+ self.adaln_single = None
192
+ self.use_additional_conditions = False
193
+ if norm_type == "ada_norm_single":
194
+ self.use_additional_conditions = self.config.sample_size == 128
195
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
196
+ # additional conditions until we find better name
197
+ self.adaln_single = AdaLayerNormSingle(
198
+ inner_dim, use_additional_conditions=self.use_additional_conditions
199
+ )
200
+
201
+ self.caption_projection = None
202
+ # if caption_channels is not None:
203
+ # self.caption_projection = CaptionProjection(
204
+ # in_features=caption_channels, hidden_size=inner_dim
205
+ # )
206
+
207
+ self.gradient_checkpointing = False
208
+
209
+ def _set_gradient_checkpointing(self, module, value=False):
210
+ if hasattr(module, "gradient_checkpointing"):
211
+ module.gradient_checkpointing = value
212
+
213
+ def forward(
214
+ self,
215
+ hidden_states: torch.Tensor,
216
+ encoder_hidden_states: Optional[torch.Tensor] = None,
217
+ timestep: Optional[torch.LongTensor] = None,
218
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
219
+ class_labels: Optional[torch.LongTensor] = None,
220
+ cross_attention_kwargs: Dict[str, Any] = None,
221
+ attention_mask: Optional[torch.Tensor] = None,
222
+ encoder_attention_mask: Optional[torch.Tensor] = None,
223
+ return_dict: bool = True,
224
+ ):
225
+ """
226
+ The [`Transformer2DModel`] forward method.
227
+
228
+ Args:
229
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
230
+ Input `hidden_states`.
231
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
232
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
233
+ self-attention.
234
+ timestep ( `torch.LongTensor`, *optional*):
235
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
236
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
237
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
238
+ `AdaLayerNormZero`.
239
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
240
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
241
+ `self.processor` in
242
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
243
+ attention_mask ( `torch.Tensor`, *optional*):
244
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
245
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
246
+ negative values to the attention scores corresponding to "discard" tokens.
247
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
248
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
249
+
250
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
251
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
252
+
253
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
254
+ above. This bias will be added to the cross-attention scores.
255
+ return_dict (`bool`, *optional*, defaults to `True`):
256
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
257
+ tuple.
258
+
259
+ Returns:
260
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
261
+ `tuple` where the first element is the sample tensor.
262
+ """
263
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
264
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
265
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
266
+ # expects mask of shape:
267
+ # [batch, key_tokens]
268
+ # adds singleton query_tokens dimension:
269
+ # [batch, 1, key_tokens]
270
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
271
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
272
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
273
+ if attention_mask is not None and attention_mask.ndim == 2:
274
+ # assume that mask is expressed as:
275
+ # (1 = keep, 0 = discard)
276
+ # convert mask into a bias that can be added to attention scores:
277
+ # (keep = +0, discard = -10000.0)
278
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
279
+ attention_mask = attention_mask.unsqueeze(1)
280
+
281
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
282
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
283
+ encoder_attention_mask = (
284
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
285
+ ) * -10000.0
286
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
287
+
288
+ # Retrieve lora scale.
289
+ lora_scale = (
290
+ cross_attention_kwargs.get("scale", 1.0)
291
+ if cross_attention_kwargs is not None
292
+ else 1.0
293
+ )
294
+
295
+ # 1. Input
296
+ batch, _, height, width = hidden_states.shape
297
+ residual = hidden_states
298
+
299
+ hidden_states = self.norm(hidden_states)
300
+ if not self.use_linear_projection:
301
+ hidden_states = (
302
+ self.proj_in(hidden_states, scale=lora_scale)
303
+ if not USE_PEFT_BACKEND
304
+ else self.proj_in(hidden_states)
305
+ )
306
+ inner_dim = hidden_states.shape[1]
307
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
308
+ batch, height * width, inner_dim
309
+ )
310
+ else:
311
+ inner_dim = hidden_states.shape[1]
312
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
313
+ batch, height * width, inner_dim
314
+ )
315
+ hidden_states = (
316
+ self.proj_in(hidden_states, scale=lora_scale)
317
+ if not USE_PEFT_BACKEND
318
+ else self.proj_in(hidden_states)
319
+ )
320
+
321
+ # 2. Blocks
322
+ if self.caption_projection is not None:
323
+ batch_size = hidden_states.shape[0]
324
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
325
+ encoder_hidden_states = encoder_hidden_states.view(
326
+ batch_size, -1, hidden_states.shape[-1]
327
+ )
328
+
329
+ ref_feature = hidden_states.reshape(batch, height, width, inner_dim)
330
+ for block in self.transformer_blocks:
331
+ if self.training and self.gradient_checkpointing:
332
+
333
+ def create_custom_forward(module, return_dict=None):
334
+ def custom_forward(*inputs):
335
+ if return_dict is not None:
336
+ return module(*inputs, return_dict=return_dict)
337
+ else:
338
+ return module(*inputs)
339
+
340
+ return custom_forward
341
+
342
+ ckpt_kwargs: Dict[str, Any] = (
343
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
344
+ )
345
+ hidden_states = torch.utils.checkpoint.checkpoint(
346
+ create_custom_forward(block),
347
+ hidden_states,
348
+ attention_mask,
349
+ encoder_hidden_states,
350
+ encoder_attention_mask,
351
+ timestep,
352
+ cross_attention_kwargs,
353
+ class_labels,
354
+ **ckpt_kwargs,
355
+ )
356
+ else:
357
+ hidden_states = block(
358
+ hidden_states,
359
+ attention_mask=attention_mask,
360
+ encoder_hidden_states=encoder_hidden_states,
361
+ encoder_attention_mask=encoder_attention_mask,
362
+ timestep=timestep,
363
+ cross_attention_kwargs=cross_attention_kwargs,
364
+ class_labels=class_labels,
365
+ )
366
+
367
+ # 3. Output
368
+ if self.is_input_continuous:
369
+ if not self.use_linear_projection:
370
+ hidden_states = (
371
+ hidden_states.reshape(batch, height, width, inner_dim)
372
+ .permute(0, 3, 1, 2)
373
+ .contiguous()
374
+ )
375
+ hidden_states = (
376
+ self.proj_out(hidden_states, scale=lora_scale)
377
+ if not USE_PEFT_BACKEND
378
+ else self.proj_out(hidden_states)
379
+ )
380
+ else:
381
+ hidden_states = (
382
+ self.proj_out(hidden_states, scale=lora_scale)
383
+ if not USE_PEFT_BACKEND
384
+ else self.proj_out(hidden_states)
385
+ )
386
+ hidden_states = (
387
+ hidden_states.reshape(batch, height, width, inner_dim)
388
+ .permute(0, 3, 1, 2)
389
+ .contiguous()
390
+ )
391
+
392
+ output = hidden_states + residual
393
+ if not return_dict:
394
+ return (output, ref_feature)
395
+
396
+ return Transformer2DModelOutput(sample=output, ref_feature=ref_feature)
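Editor's note: unlike the stock diffusers `Transformer2DModel`, this adapted class also returns `ref_feature`, the hidden states captured before the transformer blocks and reshaped to `(batch, height, width, inner_dim)`. A minimal usage sketch follows; it is an illustration only, the channel and spatial sizes are invented, and it assumes the repository's `BasicTransformerBlock` accepts the same defaults as the diffusers one.

import torch

model = Transformer2DModel(num_attention_heads=8, attention_head_dim=40, in_channels=320)
x = torch.randn(1, 320, 16, 16)                    # (batch, channels, height, width)
sample, ref_feature = model(x, return_dict=False)  # sample: (1, 320, 16, 16), ref_feature: (1, 16, 16, 320)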
src/models/unet_2d_blocks.py ADDED
@@ -0,0 +1,1131 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+ from typing import Any, Dict, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from diffusers.models.activations import get_activation
8
+ from diffusers.models.attention_processor import Attention
9
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
10
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
11
+ from diffusers.utils import is_torch_version, logging
12
+ from diffusers.utils.torch_utils import apply_freeu
13
+ from torch import nn
14
+
15
+ from .transformer_2d import Transformer2DModel
16
+
17
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
18
+
19
+
20
+ def get_down_block(
21
+ down_block_type: str,
22
+ num_layers: int,
23
+ in_channels: int,
24
+ out_channels: int,
25
+ temb_channels: int,
26
+ add_downsample: bool,
27
+ resnet_eps: float,
28
+ resnet_act_fn: str,
29
+ transformer_layers_per_block: int = 1,
30
+ num_attention_heads: Optional[int] = None,
31
+ resnet_groups: Optional[int] = None,
32
+ cross_attention_dim: Optional[int] = None,
33
+ downsample_padding: Optional[int] = None,
34
+ dual_cross_attention: bool = False,
35
+ use_linear_projection: bool = False,
36
+ only_cross_attention: bool = False,
37
+ upcast_attention: bool = False,
38
+ resnet_time_scale_shift: str = "default",
39
+ attention_type: str = "default",
40
+ resnet_skip_time_act: bool = False,
41
+ resnet_out_scale_factor: float = 1.0,
42
+ cross_attention_norm: Optional[str] = None,
43
+ attention_head_dim: Optional[int] = None,
44
+ downsample_type: Optional[str] = None,
45
+ dropout: float = 0.0,
46
+ ):
47
+ # If attn head dim is not defined, we default it to the number of heads
48
+ if attention_head_dim is None:
49
+ logger.warning(
50
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
51
+ )
52
+ attention_head_dim = num_attention_heads
53
+
54
+ down_block_type = (
55
+ down_block_type[7:]
56
+ if down_block_type.startswith("UNetRes")
57
+ else down_block_type
58
+ )
59
+ if down_block_type == "DownBlock2D":
60
+ return DownBlock2D(
61
+ num_layers=num_layers,
62
+ in_channels=in_channels,
63
+ out_channels=out_channels,
64
+ temb_channels=temb_channels,
65
+ dropout=dropout,
66
+ add_downsample=add_downsample,
67
+ resnet_eps=resnet_eps,
68
+ resnet_act_fn=resnet_act_fn,
69
+ resnet_groups=resnet_groups,
70
+ downsample_padding=downsample_padding,
71
+ resnet_time_scale_shift=resnet_time_scale_shift,
72
+ )
73
+ elif down_block_type == "CrossAttnDownBlock2D":
74
+ if cross_attention_dim is None:
75
+ raise ValueError(
76
+ "cross_attention_dim must be specified for CrossAttnDownBlock2D"
77
+ )
78
+ return CrossAttnDownBlock2D(
79
+ num_layers=num_layers,
80
+ transformer_layers_per_block=transformer_layers_per_block,
81
+ in_channels=in_channels,
82
+ out_channels=out_channels,
83
+ temb_channels=temb_channels,
84
+ dropout=dropout,
85
+ add_downsample=add_downsample,
86
+ resnet_eps=resnet_eps,
87
+ resnet_act_fn=resnet_act_fn,
88
+ resnet_groups=resnet_groups,
89
+ downsample_padding=downsample_padding,
90
+ cross_attention_dim=cross_attention_dim,
91
+ num_attention_heads=num_attention_heads,
92
+ dual_cross_attention=dual_cross_attention,
93
+ use_linear_projection=use_linear_projection,
94
+ only_cross_attention=only_cross_attention,
95
+ upcast_attention=upcast_attention,
96
+ resnet_time_scale_shift=resnet_time_scale_shift,
97
+ attention_type=attention_type,
98
+ )
99
+ raise ValueError(f"{down_block_type} does not exist.")
100
+
101
+
102
+ def get_up_block(
103
+ up_block_type: str,
104
+ num_layers: int,
105
+ in_channels: int,
106
+ out_channels: int,
107
+ prev_output_channel: int,
108
+ temb_channels: int,
109
+ add_upsample: bool,
110
+ resnet_eps: float,
111
+ resnet_act_fn: str,
112
+ resolution_idx: Optional[int] = None,
113
+ transformer_layers_per_block: int = 1,
114
+ num_attention_heads: Optional[int] = None,
115
+ resnet_groups: Optional[int] = None,
116
+ cross_attention_dim: Optional[int] = None,
117
+ dual_cross_attention: bool = False,
118
+ use_linear_projection: bool = False,
119
+ only_cross_attention: bool = False,
120
+ upcast_attention: bool = False,
121
+ resnet_time_scale_shift: str = "default",
122
+ attention_type: str = "default",
123
+ resnet_skip_time_act: bool = False,
124
+ resnet_out_scale_factor: float = 1.0,
125
+ cross_attention_norm: Optional[str] = None,
126
+ attention_head_dim: Optional[int] = None,
127
+ upsample_type: Optional[str] = None,
128
+ dropout: float = 0.0,
129
+ ) -> nn.Module:
130
+ # If attn head dim is not defined, we default it to the number of heads
131
+ if attention_head_dim is None:
132
+ logger.warning(
133
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
134
+ )
135
+ attention_head_dim = num_attention_heads
136
+
137
+ up_block_type = (
138
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
139
+ )
140
+ if up_block_type == "UpBlock2D":
141
+ return UpBlock2D(
142
+ num_layers=num_layers,
143
+ in_channels=in_channels,
144
+ out_channels=out_channels,
145
+ prev_output_channel=prev_output_channel,
146
+ temb_channels=temb_channels,
147
+ resolution_idx=resolution_idx,
148
+ dropout=dropout,
149
+ add_upsample=add_upsample,
150
+ resnet_eps=resnet_eps,
151
+ resnet_act_fn=resnet_act_fn,
152
+ resnet_groups=resnet_groups,
153
+ resnet_time_scale_shift=resnet_time_scale_shift,
154
+ )
155
+ elif up_block_type == "CrossAttnUpBlock2D":
156
+ if cross_attention_dim is None:
157
+ raise ValueError(
158
+ "cross_attention_dim must be specified for CrossAttnUpBlock2D"
159
+ )
160
+ return CrossAttnUpBlock2D(
161
+ num_layers=num_layers,
162
+ transformer_layers_per_block=transformer_layers_per_block,
163
+ in_channels=in_channels,
164
+ out_channels=out_channels,
165
+ prev_output_channel=prev_output_channel,
166
+ temb_channels=temb_channels,
167
+ resolution_idx=resolution_idx,
168
+ dropout=dropout,
169
+ add_upsample=add_upsample,
170
+ resnet_eps=resnet_eps,
171
+ resnet_act_fn=resnet_act_fn,
172
+ resnet_groups=resnet_groups,
173
+ cross_attention_dim=cross_attention_dim,
174
+ num_attention_heads=num_attention_heads,
175
+ dual_cross_attention=dual_cross_attention,
176
+ use_linear_projection=use_linear_projection,
177
+ only_cross_attention=only_cross_attention,
178
+ upcast_attention=upcast_attention,
179
+ resnet_time_scale_shift=resnet_time_scale_shift,
180
+ attention_type=attention_type,
181
+ )
182
+
183
+ raise ValueError(f"{up_block_type} does not exist.")
184
+ def get_mid_block(
185
+ mid_block_type: str,
186
+ temb_channels: int,
187
+ in_channels: int,
188
+ resnet_eps: float,
189
+ resnet_act_fn: str,
190
+ resnet_groups: int,
191
+ output_scale_factor: float = 1.0,
192
+ transformer_layers_per_block: int = 1,
193
+ num_attention_heads: Optional[int] = None,
194
+ cross_attention_dim: Optional[int] = None,
195
+ dual_cross_attention: bool = False,
196
+ use_linear_projection: bool = False,
197
+ mid_block_only_cross_attention: bool = False,
198
+ upcast_attention: bool = False,
199
+ resnet_time_scale_shift: str = "default",
200
+ attention_type: str = "default",
201
+ resnet_skip_time_act: bool = False,
202
+ cross_attention_norm: Optional[str] = None,
203
+ attention_head_dim: Optional[int] = 1,
204
+ dropout: float = 0.0,
205
+ ):
206
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
207
+ return UNetMidBlock2DCrossAttn(
208
+ transformer_layers_per_block=transformer_layers_per_block,
209
+ in_channels=in_channels,
210
+ temb_channels=temb_channels,
211
+ dropout=dropout,
212
+ resnet_eps=resnet_eps,
213
+ resnet_act_fn=resnet_act_fn,
214
+ output_scale_factor=output_scale_factor,
215
+ resnet_time_scale_shift=resnet_time_scale_shift,
216
+ cross_attention_dim=cross_attention_dim,
217
+ num_attention_heads=num_attention_heads,
218
+ resnet_groups=resnet_groups,
219
+ dual_cross_attention=dual_cross_attention,
220
+ use_linear_projection=use_linear_projection,
221
+ upcast_attention=upcast_attention,
222
+ attention_type=attention_type,
223
+ )
224
+ elif mid_block_type == "UNetMidBlock2D":
225
+ return UNetMidBlock2D(
226
+ in_channels=in_channels,
227
+ temb_channels=temb_channels,
228
+ dropout=dropout,
229
+ num_layers=0,
230
+ resnet_eps=resnet_eps,
231
+ resnet_act_fn=resnet_act_fn,
232
+ output_scale_factor=output_scale_factor,
233
+ resnet_groups=resnet_groups,
234
+ resnet_time_scale_shift=resnet_time_scale_shift,
235
+ add_attention=False,
236
+ )
237
+ elif mid_block_type is None:
238
+ return None
239
+ else:
240
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
241
+
242
+
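Editor's note: the factory functions above simply dispatch on the block-type string. A hedged example of building a plain down block (illustration only; the channel sizes are invented):

down = get_down_block(
    "DownBlock2D",
    num_layers=2,
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    resnet_groups=32,        # pass explicitly; the factory forwards None otherwise
    downsample_padding=1,
    attention_head_dim=8,    # unused by DownBlock2D, only silences the fallback warning
)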
243
+ class AutoencoderTinyBlock(nn.Module):
244
+ """
245
+ Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
246
+ blocks.
247
+
248
+ Args:
249
+ in_channels (`int`): The number of input channels.
250
+ out_channels (`int`): The number of output channels.
251
+ act_fn (`str`):
252
+ The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
253
+
254
+ Returns:
255
+ `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
256
+ `out_channels`.
257
+ """
258
+
259
+ def __init__(self, in_channels: int, out_channels: int, act_fn: str):
260
+ super().__init__()
261
+ act_fn = get_activation(act_fn)
262
+ self.conv = nn.Sequential(
263
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
264
+ act_fn,
265
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
266
+ act_fn,
267
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
268
+ )
269
+ self.skip = (
270
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
271
+ if in_channels != out_channels
272
+ else nn.Identity()
273
+ )
274
+ self.fuse = nn.ReLU()
275
+
276
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
277
+ return self.fuse(self.conv(x) + self.skip(x))
278
+
279
+
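Editor's note: a minimal sketch of using the residual conv block above (illustration only; sizes invented):

import torch

block = AutoencoderTinyBlock(in_channels=64, out_channels=64, act_fn="relu")
y = block(torch.randn(1, 64, 32, 32))  # same spatial size, `out_channels` channels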
280
+ class UNetMidBlock2D(nn.Module):
281
+ """
282
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
283
+
284
+ Args:
285
+ in_channels (`int`): The number of input channels.
286
+ temb_channels (`int`): The number of temporal embedding channels.
287
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
288
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
289
+ resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
290
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
291
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
292
+ model on tasks with long-range temporal dependencies.
293
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
294
+ resnet_groups (`int`, *optional*, defaults to 32):
295
+ The number of groups to use in the group normalization layers of the resnet blocks.
296
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
297
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
298
+ Whether to use pre-normalization for the resnet blocks.
299
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
300
+ attention_head_dim (`int`, *optional*, defaults to 1):
301
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
302
+ the number of input channels.
303
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
304
+
305
+ Returns:
306
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
307
+ in_channels, height, width)`.
308
+
309
+ """
310
+
311
+ def __init__(
312
+ self,
313
+ in_channels: int,
314
+ temb_channels: int,
315
+ dropout: float = 0.0,
316
+ num_layers: int = 1,
317
+ resnet_eps: float = 1e-6,
318
+ resnet_time_scale_shift: str = "default", # default, spatial
319
+ resnet_act_fn: str = "swish",
320
+ resnet_groups: int = 32,
321
+ attn_groups: Optional[int] = None,
322
+ resnet_pre_norm: bool = True,
323
+ add_attention: bool = True,
324
+ attention_head_dim: int = 1,
325
+ output_scale_factor: float = 1.0,
326
+ ):
327
+ super().__init__()
328
+ resnet_groups = (
329
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
330
+ )
331
+ self.add_attention = add_attention
332
+
333
+ if attn_groups is None:
334
+ attn_groups = (
335
+ resnet_groups if resnet_time_scale_shift == "default" else None
336
+ )
337
+
338
+ # there is always at least one resnet
339
+ resnets = [
340
+ ResnetBlock2D(
341
+ in_channels=in_channels,
342
+ out_channels=in_channels,
343
+ temb_channels=temb_channels,
344
+ eps=resnet_eps,
345
+ groups=resnet_groups,
346
+ dropout=dropout,
347
+ time_embedding_norm=resnet_time_scale_shift,
348
+ non_linearity=resnet_act_fn,
349
+ output_scale_factor=output_scale_factor,
350
+ pre_norm=resnet_pre_norm,
351
+ )
352
+ ]
353
+ attentions = []
354
+
355
+ if attention_head_dim is None:
356
+ logger.warning(
357
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
358
+ )
359
+ attention_head_dim = in_channels
360
+
361
+ for _ in range(num_layers):
362
+ if self.add_attention:
363
+ attentions.append(
364
+ Attention(
365
+ in_channels,
366
+ heads=in_channels // attention_head_dim,
367
+ dim_head=attention_head_dim,
368
+ rescale_output_factor=output_scale_factor,
369
+ eps=resnet_eps,
370
+ norm_num_groups=attn_groups,
371
+ spatial_norm_dim=temb_channels
372
+ if resnet_time_scale_shift == "spatial"
373
+ else None,
374
+ residual_connection=True,
375
+ bias=True,
376
+ upcast_softmax=True,
377
+ _from_deprecated_attn_block=True,
378
+ )
379
+ )
380
+ else:
381
+ attentions.append(None)
382
+
383
+ resnets.append(
384
+ ResnetBlock2D(
385
+ in_channels=in_channels,
386
+ out_channels=in_channels,
387
+ temb_channels=temb_channels,
388
+ eps=resnet_eps,
389
+ groups=resnet_groups,
390
+ dropout=dropout,
391
+ time_embedding_norm=resnet_time_scale_shift,
392
+ non_linearity=resnet_act_fn,
393
+ output_scale_factor=output_scale_factor,
394
+ pre_norm=resnet_pre_norm,
395
+ )
396
+ )
397
+
398
+ self.attentions = nn.ModuleList(attentions)
399
+ self.resnets = nn.ModuleList(resnets)
400
+
401
+ def forward(
402
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
403
+ ) -> torch.FloatTensor:
404
+ hidden_states = self.resnets[0](hidden_states, temb)
405
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
406
+ if attn is not None:
407
+ hidden_states = attn(hidden_states, temb=temb)
408
+ hidden_states = resnet(hidden_states, temb)
409
+
410
+ return hidden_states
411
+
412
+
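Editor's note: `get_mid_block` above instantiates this class with `num_layers=0` and `add_attention=False` for the `"UNetMidBlock2D"` case, which leaves only the single mandatory resnet. A small check (illustration only; channel sizes invented):

mid = UNetMidBlock2D(in_channels=1280, temb_channels=1280, num_layers=0, add_attention=False)
assert len(mid.resnets) == 1 and len(mid.attentions) == 0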
413
+ class UNetMidBlock2DCrossAttn(nn.Module):
414
+ def __init__(
415
+ self,
416
+ in_channels: int,
417
+ temb_channels: int,
418
+ dropout: float = 0.0,
419
+ num_layers: int = 1,
420
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
421
+ resnet_eps: float = 1e-6,
422
+ resnet_time_scale_shift: str = "default",
423
+ resnet_act_fn: str = "swish",
424
+ resnet_groups: int = 32,
425
+ resnet_pre_norm: bool = True,
426
+ num_attention_heads: int = 1,
427
+ output_scale_factor: float = 1.0,
428
+ cross_attention_dim: int = 1280,
429
+ dual_cross_attention: bool = False,
430
+ use_linear_projection: bool = False,
431
+ upcast_attention: bool = False,
432
+ attention_type: str = "default",
433
+ ):
434
+ super().__init__()
435
+
436
+ self.has_cross_attention = True
437
+ self.num_attention_heads = num_attention_heads
438
+ resnet_groups = (
439
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
440
+ )
441
+
442
+ # support for variable transformer layers per block
443
+ if isinstance(transformer_layers_per_block, int):
444
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
445
+
446
+ # there is always at least one resnet
447
+ resnets = [
448
+ ResnetBlock2D(
449
+ in_channels=in_channels,
450
+ out_channels=in_channels,
451
+ temb_channels=temb_channels,
452
+ eps=resnet_eps,
453
+ groups=resnet_groups,
454
+ dropout=dropout,
455
+ time_embedding_norm=resnet_time_scale_shift,
456
+ non_linearity=resnet_act_fn,
457
+ output_scale_factor=output_scale_factor,
458
+ pre_norm=resnet_pre_norm,
459
+ )
460
+ ]
461
+ attentions = []
462
+
463
+ for i in range(num_layers):
464
+ if not dual_cross_attention:
465
+ attentions.append(
466
+ Transformer2DModel(
467
+ num_attention_heads,
468
+ in_channels // num_attention_heads,
469
+ in_channels=in_channels,
470
+ num_layers=transformer_layers_per_block[i],
471
+ cross_attention_dim=cross_attention_dim,
472
+ norm_num_groups=resnet_groups,
473
+ use_linear_projection=use_linear_projection,
474
+ upcast_attention=upcast_attention,
475
+ attention_type=attention_type,
476
+ )
477
+ )
478
+ else:
479
+ attentions.append(
480
+ DualTransformer2DModel(
481
+ num_attention_heads,
482
+ in_channels // num_attention_heads,
483
+ in_channels=in_channels,
484
+ num_layers=1,
485
+ cross_attention_dim=cross_attention_dim,
486
+ norm_num_groups=resnet_groups,
487
+ )
488
+ )
489
+ resnets.append(
490
+ ResnetBlock2D(
491
+ in_channels=in_channels,
492
+ out_channels=in_channels,
493
+ temb_channels=temb_channels,
494
+ eps=resnet_eps,
495
+ groups=resnet_groups,
496
+ dropout=dropout,
497
+ time_embedding_norm=resnet_time_scale_shift,
498
+ non_linearity=resnet_act_fn,
499
+ output_scale_factor=output_scale_factor,
500
+ pre_norm=resnet_pre_norm,
501
+ )
502
+ )
503
+
504
+ self.attentions = nn.ModuleList(attentions)
505
+ self.resnets = nn.ModuleList(resnets)
506
+
507
+ self.gradient_checkpointing = False
508
+
509
+ def forward(
510
+ self,
511
+ hidden_states: torch.FloatTensor,
512
+ temb: Optional[torch.FloatTensor] = None,
513
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
514
+ attention_mask: Optional[torch.FloatTensor] = None,
515
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
516
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
517
+ ) -> torch.FloatTensor:
518
+ lora_scale = (
519
+ cross_attention_kwargs.get("scale", 1.0)
520
+ if cross_attention_kwargs is not None
521
+ else 1.0
522
+ )
523
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
524
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
525
+ if self.training and self.gradient_checkpointing:
526
+
527
+ def create_custom_forward(module, return_dict=None):
528
+ def custom_forward(*inputs):
529
+ if return_dict is not None:
530
+ return module(*inputs, return_dict=return_dict)
531
+ else:
532
+ return module(*inputs)
533
+
534
+ return custom_forward
535
+
536
+ ckpt_kwargs: Dict[str, Any] = (
537
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
538
+ )
539
+ hidden_states, ref_feature = attn(
540
+ hidden_states,
541
+ encoder_hidden_states=encoder_hidden_states,
542
+ cross_attention_kwargs=cross_attention_kwargs,
543
+ attention_mask=attention_mask,
544
+ encoder_attention_mask=encoder_attention_mask,
545
+ return_dict=False,
546
+ )
547
+ hidden_states = torch.utils.checkpoint.checkpoint(
548
+ create_custom_forward(resnet),
549
+ hidden_states,
550
+ temb,
551
+ **ckpt_kwargs,
552
+ )
553
+ else:
554
+ hidden_states, ref_feature = attn(
555
+ hidden_states,
556
+ encoder_hidden_states=encoder_hidden_states,
557
+ cross_attention_kwargs=cross_attention_kwargs,
558
+ attention_mask=attention_mask,
559
+ encoder_attention_mask=encoder_attention_mask,
560
+ return_dict=False,
561
+ )
562
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
563
+
564
+ return hidden_states
565
+
566
+
567
+ class CrossAttnDownBlock2D(nn.Module):
568
+ def __init__(
569
+ self,
570
+ in_channels: int,
571
+ out_channels: int,
572
+ temb_channels: int,
573
+ dropout: float = 0.0,
574
+ num_layers: int = 1,
575
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
576
+ resnet_eps: float = 1e-6,
577
+ resnet_time_scale_shift: str = "default",
578
+ resnet_act_fn: str = "swish",
579
+ resnet_groups: int = 32,
580
+ resnet_pre_norm: bool = True,
581
+ num_attention_heads: int = 1,
582
+ cross_attention_dim: int = 1280,
583
+ output_scale_factor: float = 1.0,
584
+ downsample_padding: int = 1,
585
+ add_downsample: bool = True,
586
+ dual_cross_attention: bool = False,
587
+ use_linear_projection: bool = False,
588
+ only_cross_attention: bool = False,
589
+ upcast_attention: bool = False,
590
+ attention_type: str = "default",
591
+ ):
592
+ super().__init__()
593
+ resnets = []
594
+ attentions = []
595
+
596
+ self.has_cross_attention = True
597
+ self.num_attention_heads = num_attention_heads
598
+ if isinstance(transformer_layers_per_block, int):
599
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
600
+
601
+ for i in range(num_layers):
602
+ in_channels = in_channels if i == 0 else out_channels
603
+ resnets.append(
604
+ ResnetBlock2D(
605
+ in_channels=in_channels,
606
+ out_channels=out_channels,
607
+ temb_channels=temb_channels,
608
+ eps=resnet_eps,
609
+ groups=resnet_groups,
610
+ dropout=dropout,
611
+ time_embedding_norm=resnet_time_scale_shift,
612
+ non_linearity=resnet_act_fn,
613
+ output_scale_factor=output_scale_factor,
614
+ pre_norm=resnet_pre_norm,
615
+ )
616
+ )
617
+ if not dual_cross_attention:
618
+ attentions.append(
619
+ Transformer2DModel(
620
+ num_attention_heads,
621
+ out_channels // num_attention_heads,
622
+ in_channels=out_channels,
623
+ num_layers=transformer_layers_per_block[i],
624
+ cross_attention_dim=cross_attention_dim,
625
+ norm_num_groups=resnet_groups,
626
+ use_linear_projection=use_linear_projection,
627
+ only_cross_attention=only_cross_attention,
628
+ upcast_attention=upcast_attention,
629
+ attention_type=attention_type,
630
+ )
631
+ )
632
+ else:
633
+ attentions.append(
634
+ DualTransformer2DModel(
635
+ num_attention_heads,
636
+ out_channels // num_attention_heads,
637
+ in_channels=out_channels,
638
+ num_layers=1,
639
+ cross_attention_dim=cross_attention_dim,
640
+ norm_num_groups=resnet_groups,
641
+ )
642
+ )
643
+ self.attentions = nn.ModuleList(attentions)
644
+ self.resnets = nn.ModuleList(resnets)
645
+
646
+ if add_downsample:
647
+ self.downsamplers = nn.ModuleList(
648
+ [
649
+ Downsample2D(
650
+ out_channels,
651
+ use_conv=True,
652
+ out_channels=out_channels,
653
+ padding=downsample_padding,
654
+ name="op",
655
+ )
656
+ ]
657
+ )
658
+ else:
659
+ self.downsamplers = None
660
+
661
+ self.gradient_checkpointing = False
662
+
663
+ def forward(
664
+ self,
665
+ hidden_states: torch.FloatTensor,
666
+ temb: Optional[torch.FloatTensor] = None,
667
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
668
+ attention_mask: Optional[torch.FloatTensor] = None,
669
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
670
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
671
+ additional_residuals: Optional[torch.FloatTensor] = None,
672
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
673
+ output_states = ()
674
+
675
+ lora_scale = (
676
+ cross_attention_kwargs.get("scale", 1.0)
677
+ if cross_attention_kwargs is not None
678
+ else 1.0
679
+ )
680
+
681
+ blocks = list(zip(self.resnets, self.attentions))
682
+
683
+ for i, (resnet, attn) in enumerate(blocks):
684
+ if self.training and self.gradient_checkpointing:
685
+
686
+ def create_custom_forward(module, return_dict=None):
687
+ def custom_forward(*inputs):
688
+ if return_dict is not None:
689
+ return module(*inputs, return_dict=return_dict)
690
+ else:
691
+ return module(*inputs)
692
+
693
+ return custom_forward
694
+
695
+ ckpt_kwargs: Dict[str, Any] = (
696
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
697
+ )
698
+ hidden_states = torch.utils.checkpoint.checkpoint(
699
+ create_custom_forward(resnet),
700
+ hidden_states,
701
+ temb,
702
+ **ckpt_kwargs,
703
+ )
704
+ hidden_states, ref_feature = attn(
705
+ hidden_states,
706
+ encoder_hidden_states=encoder_hidden_states,
707
+ cross_attention_kwargs=cross_attention_kwargs,
708
+ attention_mask=attention_mask,
709
+ encoder_attention_mask=encoder_attention_mask,
710
+ return_dict=False,
711
+ )
712
+ else:
713
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
714
+ hidden_states, ref_feature = attn(
715
+ hidden_states,
716
+ encoder_hidden_states=encoder_hidden_states,
717
+ cross_attention_kwargs=cross_attention_kwargs,
718
+ attention_mask=attention_mask,
719
+ encoder_attention_mask=encoder_attention_mask,
720
+ return_dict=False,
721
+ )
722
+
723
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
724
+ if i == len(blocks) - 1 and additional_residuals is not None:
725
+ hidden_states = hidden_states + additional_residuals
726
+
727
+ output_states = output_states + (hidden_states,)
728
+
729
+ if self.downsamplers is not None:
730
+ for downsampler in self.downsamplers:
731
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
732
+
733
+ output_states = output_states + (hidden_states,)
734
+
735
+ return hidden_states, output_states
736
+
737
+
738
+ class DownBlock2D(nn.Module):
739
+ def __init__(
740
+ self,
741
+ in_channels: int,
742
+ out_channels: int,
743
+ temb_channels: int,
744
+ dropout: float = 0.0,
745
+ num_layers: int = 1,
746
+ resnet_eps: float = 1e-6,
747
+ resnet_time_scale_shift: str = "default",
748
+ resnet_act_fn: str = "swish",
749
+ resnet_groups: int = 32,
750
+ resnet_pre_norm: bool = True,
751
+ output_scale_factor: float = 1.0,
752
+ add_downsample: bool = True,
753
+ downsample_padding: int = 1,
754
+ ):
755
+ super().__init__()
756
+ resnets = []
757
+
758
+ for i in range(num_layers):
759
+ in_channels = in_channels if i == 0 else out_channels
760
+ resnets.append(
761
+ ResnetBlock2D(
762
+ in_channels=in_channels,
763
+ out_channels=out_channels,
764
+ temb_channels=temb_channels,
765
+ eps=resnet_eps,
766
+ groups=resnet_groups,
767
+ dropout=dropout,
768
+ time_embedding_norm=resnet_time_scale_shift,
769
+ non_linearity=resnet_act_fn,
770
+ output_scale_factor=output_scale_factor,
771
+ pre_norm=resnet_pre_norm,
772
+ )
773
+ )
774
+
775
+ self.resnets = nn.ModuleList(resnets)
776
+
777
+ if add_downsample:
778
+ self.downsamplers = nn.ModuleList(
779
+ [
780
+ Downsample2D(
781
+ out_channels,
782
+ use_conv=True,
783
+ out_channels=out_channels,
784
+ padding=downsample_padding,
785
+ name="op",
786
+ )
787
+ ]
788
+ )
789
+ else:
790
+ self.downsamplers = None
791
+
792
+ self.gradient_checkpointing = False
793
+
794
+ def forward(
795
+ self,
796
+ hidden_states: torch.FloatTensor,
797
+ temb: Optional[torch.FloatTensor] = None,
798
+ scale: float = 1.0,
799
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
800
+ output_states = ()
801
+
802
+ for resnet in self.resnets:
803
+ if self.training and self.gradient_checkpointing:
804
+
805
+ def create_custom_forward(module):
806
+ def custom_forward(*inputs):
807
+ return module(*inputs)
808
+
809
+ return custom_forward
810
+
811
+ if is_torch_version(">=", "1.11.0"):
812
+ hidden_states = torch.utils.checkpoint.checkpoint(
813
+ create_custom_forward(resnet),
814
+ hidden_states,
815
+ temb,
816
+ use_reentrant=False,
817
+ )
818
+ else:
819
+ hidden_states = torch.utils.checkpoint.checkpoint(
820
+ create_custom_forward(resnet), hidden_states, temb
821
+ )
822
+ else:
823
+ hidden_states = resnet(hidden_states, temb, scale=scale)
824
+
825
+ output_states = output_states + (hidden_states,)
826
+
827
+ if self.downsamplers is not None:
828
+ for downsampler in self.downsamplers:
829
+ hidden_states = downsampler(hidden_states, scale=scale)
830
+
831
+ output_states = output_states + (hidden_states,)
832
+
833
+ return hidden_states, output_states
834
+
835
+
836
+ class CrossAttnUpBlock2D(nn.Module):
837
+ def __init__(
838
+ self,
839
+ in_channels: int,
840
+ out_channels: int,
841
+ prev_output_channel: int,
842
+ temb_channels: int,
843
+ resolution_idx: Optional[int] = None,
844
+ dropout: float = 0.0,
845
+ num_layers: int = 1,
846
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
847
+ resnet_eps: float = 1e-6,
848
+ resnet_time_scale_shift: str = "default",
849
+ resnet_act_fn: str = "swish",
850
+ resnet_groups: int = 32,
851
+ resnet_pre_norm: bool = True,
852
+ num_attention_heads: int = 1,
853
+ cross_attention_dim: int = 1280,
854
+ output_scale_factor: float = 1.0,
855
+ add_upsample: bool = True,
856
+ dual_cross_attention: bool = False,
857
+ use_linear_projection: bool = False,
858
+ only_cross_attention: bool = False,
859
+ upcast_attention: bool = False,
860
+ attention_type: str = "default",
861
+ ):
862
+ super().__init__()
863
+ resnets = []
864
+ attentions = []
865
+
866
+ self.has_cross_attention = True
867
+ self.num_attention_heads = num_attention_heads
868
+
869
+ if isinstance(transformer_layers_per_block, int):
870
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
871
+
872
+ for i in range(num_layers):
873
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
874
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
875
+
876
+ resnets.append(
877
+ ResnetBlock2D(
878
+ in_channels=resnet_in_channels + res_skip_channels,
879
+ out_channels=out_channels,
880
+ temb_channels=temb_channels,
881
+ eps=resnet_eps,
882
+ groups=resnet_groups,
883
+ dropout=dropout,
884
+ time_embedding_norm=resnet_time_scale_shift,
885
+ non_linearity=resnet_act_fn,
886
+ output_scale_factor=output_scale_factor,
887
+ pre_norm=resnet_pre_norm,
888
+ )
889
+ )
890
+ if not dual_cross_attention:
891
+ attentions.append(
892
+ Transformer2DModel(
893
+ num_attention_heads,
894
+ out_channels // num_attention_heads,
895
+ in_channels=out_channels,
896
+ num_layers=transformer_layers_per_block[i],
897
+ cross_attention_dim=cross_attention_dim,
898
+ norm_num_groups=resnet_groups,
899
+ use_linear_projection=use_linear_projection,
900
+ only_cross_attention=only_cross_attention,
901
+ upcast_attention=upcast_attention,
902
+ attention_type=attention_type,
903
+ )
904
+ )
905
+ else:
906
+ attentions.append(
907
+ DualTransformer2DModel(
908
+ num_attention_heads,
909
+ out_channels // num_attention_heads,
910
+ in_channels=out_channels,
911
+ num_layers=1,
912
+ cross_attention_dim=cross_attention_dim,
913
+ norm_num_groups=resnet_groups,
914
+ )
915
+ )
916
+ self.attentions = nn.ModuleList(attentions)
917
+ self.resnets = nn.ModuleList(resnets)
918
+
919
+ if add_upsample:
920
+ self.upsamplers = nn.ModuleList(
921
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
922
+ )
923
+ else:
924
+ self.upsamplers = None
925
+
926
+ self.gradient_checkpointing = False
927
+ self.resolution_idx = resolution_idx
928
+
929
+ def forward(
930
+ self,
931
+ hidden_states: torch.FloatTensor,
932
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
933
+ temb: Optional[torch.FloatTensor] = None,
934
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
935
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
936
+ upsample_size: Optional[int] = None,
937
+ attention_mask: Optional[torch.FloatTensor] = None,
938
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
939
+ ) -> torch.FloatTensor:
940
+ lora_scale = (
941
+ cross_attention_kwargs.get("scale", 1.0)
942
+ if cross_attention_kwargs is not None
943
+ else 1.0
944
+ )
945
+ is_freeu_enabled = (
946
+ getattr(self, "s1", None)
947
+ and getattr(self, "s2", None)
948
+ and getattr(self, "b1", None)
949
+ and getattr(self, "b2", None)
950
+ )
951
+
952
+ for resnet, attn in zip(self.resnets, self.attentions):
953
+ # pop res hidden states
954
+ res_hidden_states = res_hidden_states_tuple[-1]
955
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
956
+
957
+ # FreeU: Only operate on the first two stages
958
+ if is_freeu_enabled:
959
+ hidden_states, res_hidden_states = apply_freeu(
960
+ self.resolution_idx,
961
+ hidden_states,
962
+ res_hidden_states,
963
+ s1=self.s1,
964
+ s2=self.s2,
965
+ b1=self.b1,
966
+ b2=self.b2,
967
+ )
968
+
969
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
970
+
971
+ if self.training and self.gradient_checkpointing:
972
+
973
+ def create_custom_forward(module, return_dict=None):
974
+ def custom_forward(*inputs):
975
+ if return_dict is not None:
976
+ return module(*inputs, return_dict=return_dict)
977
+ else:
978
+ return module(*inputs)
979
+
980
+ return custom_forward
981
+
982
+ ckpt_kwargs: Dict[str, Any] = (
983
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
984
+ )
985
+ hidden_states = torch.utils.checkpoint.checkpoint(
986
+ create_custom_forward(resnet),
987
+ hidden_states,
988
+ temb,
989
+ **ckpt_kwargs,
990
+ )
991
+ hidden_states, ref_feature = attn(
992
+ hidden_states,
993
+ encoder_hidden_states=encoder_hidden_states,
994
+ cross_attention_kwargs=cross_attention_kwargs,
995
+ attention_mask=attention_mask,
996
+ encoder_attention_mask=encoder_attention_mask,
997
+ return_dict=False,
998
+ )
999
+ else:
1000
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1001
+ hidden_states, ref_feature = attn(
1002
+ hidden_states,
1003
+ encoder_hidden_states=encoder_hidden_states,
1004
+ cross_attention_kwargs=cross_attention_kwargs,
1005
+ attention_mask=attention_mask,
1006
+ encoder_attention_mask=encoder_attention_mask,
1007
+ return_dict=False,
1008
+ )
1009
+
1010
+ if self.upsamplers is not None:
1011
+ for upsampler in self.upsamplers:
1012
+ hidden_states = upsampler(
1013
+ hidden_states, upsample_size, scale=lora_scale
1014
+ )
1015
+
1016
+ return hidden_states
1017
+
1018
+
1019
+ class UpBlock2D(nn.Module):
1020
+ def __init__(
1021
+ self,
1022
+ in_channels: int,
1023
+ prev_output_channel: int,
1024
+ out_channels: int,
1025
+ temb_channels: int,
1026
+ resolution_idx: Optional[int] = None,
1027
+ dropout: float = 0.0,
1028
+ num_layers: int = 1,
1029
+ resnet_eps: float = 1e-6,
1030
+ resnet_time_scale_shift: str = "default",
1031
+ resnet_act_fn: str = "swish",
1032
+ resnet_groups: int = 32,
1033
+ resnet_pre_norm: bool = True,
1034
+ output_scale_factor: float = 1.0,
1035
+ add_upsample: bool = True,
1036
+ ):
1037
+ super().__init__()
1038
+ resnets = []
1039
+
1040
+ for i in range(num_layers):
1041
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1042
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1043
+
1044
+ resnets.append(
1045
+ ResnetBlock2D(
1046
+ in_channels=resnet_in_channels + res_skip_channels,
1047
+ out_channels=out_channels,
1048
+ temb_channels=temb_channels,
1049
+ eps=resnet_eps,
1050
+ groups=resnet_groups,
1051
+ dropout=dropout,
1052
+ time_embedding_norm=resnet_time_scale_shift,
1053
+ non_linearity=resnet_act_fn,
1054
+ output_scale_factor=output_scale_factor,
1055
+ pre_norm=resnet_pre_norm,
1056
+ )
1057
+ )
1058
+
1059
+ self.resnets = nn.ModuleList(resnets)
1060
+
1061
+ if add_upsample:
1062
+ self.upsamplers = nn.ModuleList(
1063
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
1064
+ )
1065
+ else:
1066
+ self.upsamplers = None
1067
+
1068
+ self.gradient_checkpointing = False
1069
+ self.resolution_idx = resolution_idx
1070
+
1071
+ def forward(
1072
+ self,
1073
+ hidden_states: torch.FloatTensor,
1074
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1075
+ temb: Optional[torch.FloatTensor] = None,
1076
+ upsample_size: Optional[int] = None,
1077
+ scale: float = 1.0,
1078
+ ) -> torch.FloatTensor:
1079
+ is_freeu_enabled = (
1080
+ getattr(self, "s1", None)
1081
+ and getattr(self, "s2", None)
1082
+ and getattr(self, "b1", None)
1083
+ and getattr(self, "b2", None)
1084
+ )
1085
+
1086
+ for resnet in self.resnets:
1087
+ # pop res hidden states
1088
+ res_hidden_states = res_hidden_states_tuple[-1]
1089
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1090
+
1091
+ # FreeU: Only operate on the first two stages
1092
+ if is_freeu_enabled:
1093
+ hidden_states, res_hidden_states = apply_freeu(
1094
+ self.resolution_idx,
1095
+ hidden_states,
1096
+ res_hidden_states,
1097
+ s1=self.s1,
1098
+ s2=self.s2,
1099
+ b1=self.b1,
1100
+ b2=self.b2,
1101
+ )
1102
+
1103
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1104
+
1105
+ if self.training and self.gradient_checkpointing:
1106
+
1107
+ def create_custom_forward(module):
1108
+ def custom_forward(*inputs):
1109
+ return module(*inputs)
1110
+
1111
+ return custom_forward
1112
+
1113
+ if is_torch_version(">=", "1.11.0"):
1114
+ hidden_states = torch.utils.checkpoint.checkpoint(
1115
+ create_custom_forward(resnet),
1116
+ hidden_states,
1117
+ temb,
1118
+ use_reentrant=False,
1119
+ )
1120
+ else:
1121
+ hidden_states = torch.utils.checkpoint.checkpoint(
1122
+ create_custom_forward(resnet), hidden_states, temb
1123
+ )
1124
+ else:
1125
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1126
+
1127
+ if self.upsamplers is not None:
1128
+ for upsampler in self.upsamplers:
1129
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
1130
+
1131
+ return hidden_states
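
A minimal usage sketch of the `UpBlock2D` defined above (illustrative only, not part of the commit). It assumes the block classes are importable as `src.models.unet_2d_blocks` from the repository root and a diffusers version whose `ResnetBlock2D`/`Upsample2D` forwards still accept the `scale` keyword, as the code above expects; channel counts, resolutions and the number of skip tensors are made-up example values.

import torch

from src.models.unet_2d_blocks import UpBlock2D  # module path assumed from this commit's layout

# One skip tensor is popped from the end of the tuple per resnet layer.
block = UpBlock2D(
    in_channels=320,          # skip channels expected by the last resnet layer
    prev_output_channel=640,  # channels arriving from the previous (deeper) up block
    out_channels=320,
    temb_channels=1280,
    num_layers=3,
    add_upsample=True,
)

hidden_states = torch.randn(1, 640, 32, 32)   # output of the previous up block
temb = torch.randn(1, 1280)                   # time embedding
skips = tuple(torch.randn(1, 320, 32, 32) for _ in range(3))  # skip features at this resolution

out = block(hidden_states, skips, temb)       # -> (1, 320, 64, 64) after the final upsampler
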
src/models/unet_2d_condition.py ADDED
@@ -0,0 +1,1305 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.utils.checkpoint
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import UNet2DConditionLoadersMixin
10
+ from diffusers.models.activations import get_activation
11
+ from diffusers.models.attention_processor import (
12
+ ADDED_KV_ATTENTION_PROCESSORS,
13
+ CROSS_ATTENTION_PROCESSORS,
14
+ Attention,  # needed below by `fuse_qkv_projections`
+ AttentionProcessor,
15
+ AttnAddedKVProcessor,
16
+ AttnProcessor,
17
+ )
18
+ from diffusers.models.embeddings import (
19
+ GaussianFourierProjection,
20
+ ImageHintTimeEmbedding,
21
+ ImageProjection,
22
+ ImageTimeEmbedding,
23
+ # PositionNet,
24
+ TextImageProjection,
25
+ TextImageTimeEmbedding,
26
+ TextTimeEmbedding,
27
+ TimestepEmbedding,
28
+ Timesteps,
29
+ )
30
+ from diffusers.models.modeling_utils import ModelMixin
31
+ from diffusers.utils import (
32
+ USE_PEFT_BACKEND,
33
+ BaseOutput,
34
+ deprecate,
35
+ logging,
36
+ scale_lora_layers,
37
+ unscale_lora_layers,
38
+ )
39
+
40
+ from .unet_2d_blocks import (
41
+ UNetMidBlock2D,
42
+ UNetMidBlock2DCrossAttn,
43
+ get_down_block,
44
+ get_up_block,
45
+ get_mid_block,
46
+ )
47
+
48
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
+
50
+ @dataclass
51
+ class UNet2DConditionOutput(BaseOutput):
52
+ """
53
+ The output of [`UNet2DConditionModel`].
54
+
55
+ Args:
56
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
57
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
58
+ """
59
+
60
+ sample: torch.FloatTensor = None
61
+
62
+
63
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
64
+ r"""
65
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
66
+ shaped output.
67
+
68
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
69
+ for all models (such as downloading or saving).
70
+
71
+ Parameters:
72
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
73
+ Height and width of input/output sample.
74
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
75
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
76
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
77
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
78
+ Whether to flip the sin to cos in the time embedding.
79
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
80
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
81
+ The tuple of downsample blocks to use.
82
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
83
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
84
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
85
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
86
+ The tuple of upsample blocks to use.
87
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
88
+ Whether to include self-attention in the basic transformer blocks, see
89
+ [`~models.attention.BasicTransformerBlock`].
90
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
91
+ The tuple of output channels for each block.
92
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
93
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
94
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
95
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
96
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
97
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
98
+ If `None`, normalization and activation layers is skipped in post-processing.
99
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
100
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
101
+ The dimension of the cross attention features.
102
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
103
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
104
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
105
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
106
+ reverse_transformer_layers_per_block (`Tuple[Tuple]`, *optional*, defaults to `None`):
107
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
108
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
109
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
110
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
111
+ encoder_hid_dim (`int`, *optional*, defaults to None):
112
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
113
+ dimension to `cross_attention_dim`.
114
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
115
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
116
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
117
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
118
+ num_attention_heads (`int`, *optional*):
119
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
120
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
121
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
122
+ class_embed_type (`str`, *optional*, defaults to `None`):
123
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
124
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
125
+ addition_embed_type (`str`, *optional*, defaults to `None`):
126
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
127
+ "text". "text" will use the `TextTimeEmbedding` layer.
128
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
129
+ Dimension for the timestep embeddings.
130
+ num_class_embeds (`int`, *optional*, defaults to `None`):
131
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
132
+ class conditioning with `class_embed_type` equal to `None`.
133
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
134
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
135
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
136
+ An optional override for the dimension of the projected time embedding.
137
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
138
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
139
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
140
+ timestep_post_act (`str`, *optional*, defaults to `None`):
141
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
142
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
143
+ The dimension of `cond_proj` layer in the timestep embedding.
144
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
145
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
146
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
147
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
148
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
149
+ embeddings with the class embeddings.
150
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
151
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
152
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
153
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False`
154
+ otherwise.
155
+ """
156
+
157
+ _supports_gradient_checkpointing = True
158
+
159
+ @register_to_config
160
+ def __init__(
161
+ self,
162
+ sample_size: Optional[int] = None,
163
+ in_channels: int = 4,
164
+ out_channels: int = 4,
165
+ center_input_sample: bool = False,
166
+ flip_sin_to_cos: bool = True,
167
+ freq_shift: int = 0,
168
+ down_block_types: Tuple[str] = (
169
+ "CrossAttnDownBlock2D",
170
+ "CrossAttnDownBlock2D",
171
+ "CrossAttnDownBlock2D",
172
+ "DownBlock2D",
173
+ ),
174
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
175
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
176
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
177
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
178
+ layers_per_block: Union[int, Tuple[int]] = 2,
179
+ downsample_padding: int = 1,
180
+ mid_block_scale_factor: float = 1,
181
+ dropout: float = 0.0,
182
+ act_fn: str = "silu",
183
+ norm_num_groups: Optional[int] = 32,
184
+ norm_eps: float = 1e-5,
185
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
186
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
187
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
188
+ encoder_hid_dim: Optional[int] = None,
189
+ encoder_hid_dim_type: Optional[str] = None,
190
+ attention_head_dim: Union[int, Tuple[int]] = 8,
191
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
192
+ dual_cross_attention: bool = False,
193
+ use_linear_projection: bool = False,
194
+ class_embed_type: Optional[str] = None,
195
+ addition_embed_type: Optional[str] = None,
196
+ addition_time_embed_dim: Optional[int] = None,
197
+ num_class_embeds: Optional[int] = None,
198
+ upcast_attention: bool = False,
199
+ resnet_time_scale_shift: str = "default",
200
+ resnet_skip_time_act: bool = False,
201
+ resnet_out_scale_factor: float = 1.0,
202
+ time_embedding_type: str = "positional",
203
+ time_embedding_dim: Optional[int] = None,
204
+ time_embedding_act_fn: Optional[str] = None,
205
+ timestep_post_act: Optional[str] = None,
206
+ time_cond_proj_dim: Optional[int] = None,
207
+ conv_in_kernel: int = 3,
208
+ conv_out_kernel: int = 3,
209
+ projection_class_embeddings_input_dim: Optional[int] = None,
210
+ attention_type: str = "default",
211
+ class_embeddings_concat: bool = False,
212
+ mid_block_only_cross_attention: Optional[bool] = None,
213
+ cross_attention_norm: Optional[str] = None,
214
+ addition_embed_type_num_heads: int = 64,
215
+ ):
216
+ super().__init__()
217
+
218
+ self.sample_size = sample_size
219
+
220
+ if num_attention_heads is not None:
221
+ raise ValueError(
222
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
223
+ )
224
+
225
+ # If `num_attention_heads` is not defined (which is the case for most models)
226
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
227
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
228
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
229
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
230
+ # which is why we correct for the naming here.
231
+ num_attention_heads = num_attention_heads or attention_head_dim
232
+
233
+ # Check inputs
234
+ self._check_config(
235
+ down_block_types=down_block_types,
236
+ up_block_types=up_block_types,
237
+ only_cross_attention=only_cross_attention,
238
+ block_out_channels=block_out_channels,
239
+ layers_per_block=layers_per_block,
240
+ cross_attention_dim=cross_attention_dim,
241
+ transformer_layers_per_block=transformer_layers_per_block,
242
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
243
+ attention_head_dim=attention_head_dim,
244
+ num_attention_heads=num_attention_heads,
245
+ )
246
+
247
+ # input
248
+ conv_in_padding = (conv_in_kernel - 1) // 2
249
+ self.conv_in = nn.Conv2d(
250
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
251
+ )
252
+
253
+ # time
254
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
255
+ time_embedding_type,
256
+ block_out_channels=block_out_channels,
257
+ flip_sin_to_cos=flip_sin_to_cos,
258
+ freq_shift=freq_shift,
259
+ time_embedding_dim=time_embedding_dim,
260
+ )
261
+
262
+ self.time_embedding = TimestepEmbedding(
263
+ timestep_input_dim,
264
+ time_embed_dim,
265
+ act_fn=act_fn,
266
+ post_act_fn=timestep_post_act,
267
+ cond_proj_dim=time_cond_proj_dim,
268
+ )
269
+
270
+ self._set_encoder_hid_proj(
271
+ encoder_hid_dim_type,
272
+ cross_attention_dim=cross_attention_dim,
273
+ encoder_hid_dim=encoder_hid_dim,
274
+ )
275
+
276
+ # class embedding
277
+ self._set_class_embedding(
278
+ class_embed_type,
279
+ act_fn=act_fn,
280
+ num_class_embeds=num_class_embeds,
281
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
282
+ time_embed_dim=time_embed_dim,
283
+ timestep_input_dim=timestep_input_dim,
284
+ )
285
+
286
+ self._set_add_embedding(
287
+ addition_embed_type,
288
+ addition_embed_type_num_heads=addition_embed_type_num_heads,
289
+ addition_time_embed_dim=addition_time_embed_dim,
290
+ cross_attention_dim=cross_attention_dim,
291
+ encoder_hid_dim=encoder_hid_dim,
292
+ flip_sin_to_cos=flip_sin_to_cos,
293
+ freq_shift=freq_shift,
294
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
295
+ time_embed_dim=time_embed_dim,
296
+ )
297
+
298
+ if time_embedding_act_fn is None:
299
+ self.time_embed_act = None
300
+ else:
301
+ self.time_embed_act = get_activation(time_embedding_act_fn)
302
+
303
+ self.down_blocks = nn.ModuleList([])
304
+ self.up_blocks = nn.ModuleList([])
305
+
306
+ if isinstance(only_cross_attention, bool):
307
+ if mid_block_only_cross_attention is None:
308
+ mid_block_only_cross_attention = only_cross_attention
309
+
310
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
311
+
312
+ if mid_block_only_cross_attention is None:
313
+ mid_block_only_cross_attention = False
314
+
315
+ if isinstance(num_attention_heads, int):
316
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
317
+
318
+ if isinstance(attention_head_dim, int):
319
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
320
+
321
+ if isinstance(cross_attention_dim, int):
322
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
323
+
324
+ if isinstance(layers_per_block, int):
325
+ layers_per_block = [layers_per_block] * len(down_block_types)
326
+
327
+ if isinstance(transformer_layers_per_block, int):
328
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
329
+
330
+ if class_embeddings_concat:
331
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
332
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
333
+ # regular time embeddings
334
+ blocks_time_embed_dim = time_embed_dim * 2
335
+ else:
336
+ blocks_time_embed_dim = time_embed_dim
337
+
338
+ # down
339
+ output_channel = block_out_channels[0]
340
+ for i, down_block_type in enumerate(down_block_types):
341
+ input_channel = output_channel
342
+ output_channel = block_out_channels[i]
343
+ is_final_block = i == len(block_out_channels) - 1
344
+
345
+ down_block = get_down_block(
346
+ down_block_type,
347
+ num_layers=layers_per_block[i],
348
+ transformer_layers_per_block=transformer_layers_per_block[i],
349
+ in_channels=input_channel,
350
+ out_channels=output_channel,
351
+ temb_channels=blocks_time_embed_dim,
352
+ add_downsample=not is_final_block,
353
+ resnet_eps=norm_eps,
354
+ resnet_act_fn=act_fn,
355
+ resnet_groups=norm_num_groups,
356
+ cross_attention_dim=cross_attention_dim[i],
357
+ num_attention_heads=num_attention_heads[i],
358
+ downsample_padding=downsample_padding,
359
+ dual_cross_attention=dual_cross_attention,
360
+ use_linear_projection=use_linear_projection,
361
+ only_cross_attention=only_cross_attention[i],
362
+ upcast_attention=upcast_attention,
363
+ resnet_time_scale_shift=resnet_time_scale_shift,
364
+ attention_type=attention_type,
365
+ resnet_skip_time_act=resnet_skip_time_act,
366
+ resnet_out_scale_factor=resnet_out_scale_factor,
367
+ cross_attention_norm=cross_attention_norm,
368
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
369
+ dropout=dropout,
370
+ )
371
+ self.down_blocks.append(down_block)
372
+
373
+ # mid
374
+ self.mid_block = get_mid_block(
375
+ mid_block_type,
376
+ temb_channels=blocks_time_embed_dim,
377
+ in_channels=block_out_channels[-1],
378
+ resnet_eps=norm_eps,
379
+ resnet_act_fn=act_fn,
380
+ resnet_groups=norm_num_groups,
381
+ output_scale_factor=mid_block_scale_factor,
382
+ transformer_layers_per_block=transformer_layers_per_block[-1],
383
+ num_attention_heads=num_attention_heads[-1],
384
+ cross_attention_dim=cross_attention_dim[-1],
385
+ dual_cross_attention=dual_cross_attention,
386
+ use_linear_projection=use_linear_projection,
387
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
388
+ upcast_attention=upcast_attention,
389
+ resnet_time_scale_shift=resnet_time_scale_shift,
390
+ attention_type=attention_type,
391
+ resnet_skip_time_act=resnet_skip_time_act,
392
+ cross_attention_norm=cross_attention_norm,
393
+ attention_head_dim=attention_head_dim[-1],
394
+ dropout=dropout,
395
+ )
396
+
397
+ # count how many layers upsample the images
398
+ self.num_upsamplers = 0
399
+
400
+ # up
401
+ reversed_block_out_channels = list(reversed(block_out_channels))
402
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
403
+ reversed_layers_per_block = list(reversed(layers_per_block))
404
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
405
+ reversed_transformer_layers_per_block = (
406
+ list(reversed(transformer_layers_per_block))
407
+ if reverse_transformer_layers_per_block is None
408
+ else reverse_transformer_layers_per_block
409
+ )
410
+ only_cross_attention = list(reversed(only_cross_attention))
411
+
412
+ output_channel = reversed_block_out_channels[0]
413
+ for i, up_block_type in enumerate(up_block_types):
414
+ is_final_block = i == len(block_out_channels) - 1
415
+
416
+ prev_output_channel = output_channel
417
+ output_channel = reversed_block_out_channels[i]
418
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
419
+
420
+ # add upsample block for all BUT final layer
421
+ if not is_final_block:
422
+ add_upsample = True
423
+ self.num_upsamplers += 1
424
+ else:
425
+ add_upsample = False
426
+
427
+ up_block = get_up_block(
428
+ up_block_type,
429
+ num_layers=reversed_layers_per_block[i] + 1,
430
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
431
+ in_channels=input_channel,
432
+ out_channels=output_channel,
433
+ prev_output_channel=prev_output_channel,
434
+ temb_channels=blocks_time_embed_dim,
435
+ add_upsample=add_upsample,
436
+ resnet_eps=norm_eps,
437
+ resnet_act_fn=act_fn,
438
+ resolution_idx=i,
439
+ resnet_groups=norm_num_groups,
440
+ cross_attention_dim=reversed_cross_attention_dim[i],
441
+ num_attention_heads=reversed_num_attention_heads[i],
442
+ dual_cross_attention=dual_cross_attention,
443
+ use_linear_projection=use_linear_projection,
444
+ only_cross_attention=only_cross_attention[i],
445
+ upcast_attention=upcast_attention,
446
+ resnet_time_scale_shift=resnet_time_scale_shift,
447
+ attention_type=attention_type,
448
+ resnet_skip_time_act=resnet_skip_time_act,
449
+ resnet_out_scale_factor=resnet_out_scale_factor,
450
+ cross_attention_norm=cross_attention_norm,
451
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
452
+ dropout=dropout,
453
+ )
454
+ self.up_blocks.append(up_block)
455
+ prev_output_channel = output_channel
456
+
457
+ # out
458
+ if norm_num_groups is not None:
459
+ self.conv_norm_out = nn.GroupNorm(
460
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
461
+ )
462
+
463
+ self.conv_act = get_activation(act_fn)
464
+
465
+ else:
466
+ self.conv_norm_out = None
467
+ self.conv_act = None
468
+
469
+ conv_out_padding = (conv_out_kernel - 1) // 2
470
+ self.conv_out = nn.Conv2d(
471
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
472
+ )
473
+
474
+ self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
475
+
476
+ def _check_config(
477
+ self,
478
+ down_block_types: Tuple[str],
479
+ up_block_types: Tuple[str],
480
+ only_cross_attention: Union[bool, Tuple[bool]],
481
+ block_out_channels: Tuple[int],
482
+ layers_per_block: Union[int, Tuple[int]],
483
+ cross_attention_dim: Union[int, Tuple[int]],
484
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
485
+ reverse_transformer_layers_per_block: bool,
486
+ attention_head_dim: int,
487
+ num_attention_heads: Optional[Union[int, Tuple[int]]],
488
+ ):
489
+ if len(down_block_types) != len(up_block_types):
490
+ raise ValueError(
491
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
492
+ )
493
+
494
+ if len(block_out_channels) != len(down_block_types):
495
+ raise ValueError(
496
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
497
+ )
498
+
499
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
500
+ raise ValueError(
501
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
502
+ )
503
+
504
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
505
+ raise ValueError(
506
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
507
+ )
508
+
509
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
510
+ raise ValueError(
511
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
512
+ )
513
+
514
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
515
+ raise ValueError(
516
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
517
+ )
518
+
519
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
520
+ raise ValueError(
521
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
522
+ )
523
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
524
+ for layer_number_per_block in transformer_layers_per_block:
525
+ if isinstance(layer_number_per_block, list):
526
+ raise ValueError("Must provide `reverse_transformer_layers_per_block` if using asymmetrical UNet.")
527
+
528
+ def _set_time_proj(
529
+ self,
530
+ time_embedding_type: str,
531
+ block_out_channels: int,
532
+ flip_sin_to_cos: bool,
533
+ freq_shift: float,
534
+ time_embedding_dim: int,
535
+ ) -> Tuple[int, int]:
536
+ if time_embedding_type == "fourier":
537
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
538
+ if time_embed_dim % 2 != 0:
539
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
540
+ self.time_proj = GaussianFourierProjection(
541
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
542
+ )
543
+ timestep_input_dim = time_embed_dim
544
+ elif time_embedding_type == "positional":
545
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
546
+
547
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
548
+ timestep_input_dim = block_out_channels[0]
549
+ else:
550
+ raise ValueError(
551
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
552
+ )
553
+
554
+ return time_embed_dim, timestep_input_dim
555
+
556
+ def _set_encoder_hid_proj(
557
+ self,
558
+ encoder_hid_dim_type: Optional[str],
559
+ cross_attention_dim: Union[int, Tuple[int]],
560
+ encoder_hid_dim: Optional[int],
561
+ ):
562
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
563
+ encoder_hid_dim_type = "text_proj"
564
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
565
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
566
+
567
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
568
+ raise ValueError(
569
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
570
+ )
571
+
572
+ if encoder_hid_dim_type == "text_proj":
573
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
574
+ elif encoder_hid_dim_type == "text_image_proj":
575
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
576
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
577
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
578
+ self.encoder_hid_proj = TextImageProjection(
579
+ text_embed_dim=encoder_hid_dim,
580
+ image_embed_dim=cross_attention_dim,
581
+ cross_attention_dim=cross_attention_dim,
582
+ )
583
+ elif encoder_hid_dim_type == "image_proj":
584
+ # Kandinsky 2.2
585
+ self.encoder_hid_proj = ImageProjection(
586
+ image_embed_dim=encoder_hid_dim,
587
+ cross_attention_dim=cross_attention_dim,
588
+ )
589
+ elif encoder_hid_dim_type is not None:
590
+ raise ValueError(
591
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
592
+ )
593
+ else:
594
+ self.encoder_hid_proj = None
595
+
596
+ def _set_class_embedding(
597
+ self,
598
+ class_embed_type: Optional[str],
599
+ act_fn: str,
600
+ num_class_embeds: Optional[int],
601
+ projection_class_embeddings_input_dim: Optional[int],
602
+ time_embed_dim: int,
603
+ timestep_input_dim: int,
604
+ ):
605
+ if class_embed_type is None and num_class_embeds is not None:
606
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
607
+ elif class_embed_type == "timestep":
608
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
609
+ elif class_embed_type == "identity":
610
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
611
+ elif class_embed_type == "projection":
612
+ if projection_class_embeddings_input_dim is None:
613
+ raise ValueError(
614
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
615
+ )
616
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
617
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
618
+ # 2. it projects from an arbitrary input dimension.
619
+ #
620
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
621
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
622
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
623
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
624
+ elif class_embed_type == "simple_projection":
625
+ if projection_class_embeddings_input_dim is None:
626
+ raise ValueError(
627
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
628
+ )
629
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
630
+ else:
631
+ self.class_embedding = None
632
+
633
+ def _set_add_embedding(
634
+ self,
635
+ addition_embed_type: str,
636
+ addition_embed_type_num_heads: int,
637
+ addition_time_embed_dim: Optional[int],
638
+ flip_sin_to_cos: bool,
639
+ freq_shift: float,
640
+ cross_attention_dim: Optional[int],
641
+ encoder_hid_dim: Optional[int],
642
+ projection_class_embeddings_input_dim: Optional[int],
643
+ time_embed_dim: int,
644
+ ):
645
+ if addition_embed_type == "text":
646
+ if encoder_hid_dim is not None:
647
+ text_time_embedding_from_dim = encoder_hid_dim
648
+ else:
649
+ text_time_embedding_from_dim = cross_attention_dim
650
+
651
+ self.add_embedding = TextTimeEmbedding(
652
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
653
+ )
654
+ elif addition_embed_type == "text_image":
655
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
656
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
657
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
658
+ self.add_embedding = TextImageTimeEmbedding(
659
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
660
+ )
661
+ elif addition_embed_type == "text_time":
662
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
663
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
664
+ elif addition_embed_type == "image":
665
+ # Kandinsky 2.2
666
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
667
+ elif addition_embed_type == "image_hint":
668
+ # Kandinsky 2.2 ControlNet
669
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
670
+ elif addition_embed_type is not None:
671
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'.")
672
+
673
+ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
674
+ if attention_type in ["gated", "gated-text-image"]:
675
+ positive_len = 768
676
+ if isinstance(cross_attention_dim, int):
677
+ positive_len = cross_attention_dim
678
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
679
+ positive_len = cross_attention_dim[0]
680
+
681
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
682
+ self.position_net = GLIGENTextBoundingboxProjection(
683
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
684
+ )
685
+
686
+ @property
687
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
688
+ r"""
689
+ Returns:
690
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
691
+ indexed by its weight name.
692
+ """
693
+ # set recursively
694
+ processors = {}
695
+
696
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
697
+ if hasattr(module, "get_processor"):
698
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
699
+
700
+ for sub_name, child in module.named_children():
701
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
702
+
703
+ return processors
704
+
705
+ for name, module in self.named_children():
706
+ fn_recursive_add_processors(name, module, processors)
707
+
708
+ return processors
709
+
710
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
711
+ r"""
712
+ Sets the attention processor to use to compute attention.
713
+
714
+ Parameters:
715
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
716
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
717
+ for **all** `Attention` layers.
718
+
719
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
720
+ processor. This is strongly recommended when setting trainable attention processors.
721
+
722
+ """
723
+ count = len(self.attn_processors.keys())
724
+
725
+ if isinstance(processor, dict) and len(processor) != count:
726
+ raise ValueError(
727
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
728
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
729
+ )
730
+
731
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
732
+ if hasattr(module, "set_processor"):
733
+ if not isinstance(processor, dict):
734
+ module.set_processor(processor)
735
+ else:
736
+ module.set_processor(processor.pop(f"{name}.processor"))
737
+
738
+ for sub_name, child in module.named_children():
739
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
740
+
741
+ for name, module in self.named_children():
742
+ fn_recursive_attn_processor(name, module, processor)
743
+
744
+ def set_default_attn_processor(self):
745
+ """
746
+ Disables custom attention processors and sets the default attention implementation.
747
+ """
748
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
749
+ processor = AttnAddedKVProcessor()
750
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
751
+ processor = AttnProcessor()
752
+ else:
753
+ raise ValueError(
754
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
755
+ )
756
+
757
+ self.set_attn_processor(processor)
758
+
759
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
760
+ r"""
761
+ Enable sliced attention computation.
762
+
763
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
764
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
765
+
766
+ Args:
767
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
768
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
769
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
770
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
771
+ must be a multiple of `slice_size`.
772
+ """
773
+ sliceable_head_dims = []
774
+
775
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
776
+ if hasattr(module, "set_attention_slice"):
777
+ sliceable_head_dims.append(module.sliceable_head_dim)
778
+
779
+ for child in module.children():
780
+ fn_recursive_retrieve_sliceable_dims(child)
781
+
782
+ # retrieve number of attention layers
783
+ for module in self.children():
784
+ fn_recursive_retrieve_sliceable_dims(module)
785
+
786
+ num_sliceable_layers = len(sliceable_head_dims)
787
+
788
+ if slice_size == "auto":
789
+ # half the attention head size is usually a good trade-off between
790
+ # speed and memory
791
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
792
+ elif slice_size == "max":
793
+ # make smallest slice possible
794
+ slice_size = num_sliceable_layers * [1]
795
+
796
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
797
+
798
+ if len(slice_size) != len(sliceable_head_dims):
799
+ raise ValueError(
800
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
801
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
802
+ )
803
+
804
+ for i in range(len(slice_size)):
805
+ size = slice_size[i]
806
+ dim = sliceable_head_dims[i]
807
+ if size is not None and size > dim:
808
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
809
+
810
+ # Recursively walk through all the children.
811
+ # Any children which exposes the set_attention_slice method
812
+ # gets the message
813
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
814
+ if hasattr(module, "set_attention_slice"):
815
+ module.set_attention_slice(slice_size.pop())
816
+
817
+ for child in module.children():
818
+ fn_recursive_set_attention_slice(child, slice_size)
819
+
820
+ reversed_slice_size = list(reversed(slice_size))
821
+ for module in self.children():
822
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
823
+
824
+ def _set_gradient_checkpointing(self, module, value=False):
825
+ if hasattr(module, "gradient_checkpointing"):
826
+ module.gradient_checkpointing = value
827
+
828
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
829
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
830
+
831
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
832
+
833
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
834
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
835
+
836
+ Args:
837
+ s1 (`float`):
838
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
839
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
840
+ s2 (`float`):
841
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
842
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
843
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
844
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
845
+ """
846
+ for i, upsample_block in enumerate(self.up_blocks):
847
+ setattr(upsample_block, "s1", s1)
848
+ setattr(upsample_block, "s2", s2)
849
+ setattr(upsample_block, "b1", b1)
850
+ setattr(upsample_block, "b2", b2)
851
+
852
+ def disable_freeu(self):
853
+ """Disables the FreeU mechanism."""
854
+ freeu_keys = {"s1", "s2", "b1", "b2"}
855
+ for i, upsample_block in enumerate(self.up_blocks):
856
+ for k in freeu_keys:
857
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
858
+ setattr(upsample_block, k, None)
859
+
860
+ def fuse_qkv_projections(self):
861
+ """
862
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
863
+ are fused. For cross-attention modules, key and value projection matrices are fused.
864
+
865
+ <Tip warning={true}>
866
+
867
+ This API is 🧪 experimental.
868
+
869
+ </Tip>
870
+ """
871
+ self.original_attn_processors = None
872
+
873
+ for _, attn_processor in self.attn_processors.items():
874
+ if "Added" in str(attn_processor.__class__.__name__):
875
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
876
+
877
+ self.original_attn_processors = self.attn_processors
878
+
879
+ for module in self.modules():
880
+ if isinstance(module, Attention):
881
+ module.fuse_projections(fuse=True)
882
+
883
+ def unfuse_qkv_projections(self):
884
+ """Disables the fused QKV projection if enabled.
885
+
886
+ <Tip warning={true}>
887
+
888
+ This API is 🧪 experimental.
889
+
890
+ </Tip>
891
+
892
+ """
893
+ if self.original_attn_processors is not None:
894
+ self.set_attn_processor(self.original_attn_processors)
895
+
896
+ def unload_lora(self):
897
+ """Unloads LoRA weights."""
898
+ deprecate(
899
+ "unload_lora",
900
+ "0.28.0",
901
+ "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters()`.",
902
+ )
903
+ for module in self.modules():
904
+ if hasattr(module, "set_lora_layer"):
905
+ module.set_lora_layer(None)
906
+
907
+ def get_time_embed(
908
+ self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
909
+ ) -> Optional[torch.Tensor]:
910
+ timesteps = timestep
911
+ if not torch.is_tensor(timesteps):
912
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
913
+ # This would be a good case for the `match` statement (Python 3.10+)
914
+ is_mps = sample.device.type == "mps"
915
+ if isinstance(timestep, float):
916
+ dtype = torch.float32 if is_mps else torch.float64
917
+ else:
918
+ dtype = torch.int32 if is_mps else torch.int64
919
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
920
+ elif len(timesteps.shape) == 0:
921
+ timesteps = timesteps[None].to(sample.device)
922
+
923
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
924
+ timesteps = timesteps.expand(sample.shape[0])
925
+
926
+ t_emb = self.time_proj(timesteps)
927
+ # `Timesteps` does not contain any weights and will always return f32 tensors
928
+ # but time_embedding might actually be running in fp16. so we need to cast here.
929
+ # there might be better ways to encapsulate this.
930
+ t_emb = t_emb.to(dtype=sample.dtype)
931
+ return t_emb
932
+
933
+ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
934
+ class_emb = None
935
+ if self.class_embedding is not None:
936
+ if class_labels is None:
937
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
938
+
939
+ if self.config.class_embed_type == "timestep":
940
+ class_labels = self.time_proj(class_labels)
941
+
942
+ # `Timesteps` does not contain any weights and will always return f32 tensors
943
+ # there might be better ways to encapsulate this.
944
+ class_labels = class_labels.to(dtype=sample.dtype)
945
+
946
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
947
+ return class_emb
948
+
949
+ def get_aug_embed(
950
+ self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
951
+ ) -> Optional[torch.Tensor]:
952
+ aug_emb = None
953
+ if self.config.addition_embed_type == "text":
954
+ aug_emb = self.add_embedding(encoder_hidden_states)
955
+ elif self.config.addition_embed_type == "text_image":
956
+ # Kandinsky 2.1 - style
957
+ if "image_embeds" not in added_cond_kwargs:
958
+ raise ValueError(
959
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
960
+ )
961
+
962
+ image_embs = added_cond_kwargs.get("image_embeds")
963
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
964
+ aug_emb = self.add_embedding(text_embs, image_embs)
965
+ elif self.config.addition_embed_type == "text_time":
966
+ # SDXL - style
967
+ if "text_embeds" not in added_cond_kwargs:
968
+ raise ValueError(
969
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
970
+ )
971
+ text_embeds = added_cond_kwargs.get("text_embeds")
972
+ if "time_ids" not in added_cond_kwargs:
973
+ raise ValueError(
974
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
975
+ )
976
+ time_ids = added_cond_kwargs.get("time_ids")
977
+ time_embeds = self.add_time_proj(time_ids.flatten())
978
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
979
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
980
+ add_embeds = add_embeds.to(emb.dtype)
981
+ aug_emb = self.add_embedding(add_embeds)
982
+ elif self.config.addition_embed_type == "image":
983
+ # Kandinsky 2.2 - style
984
+ if "image_embeds" not in added_cond_kwargs:
985
+ raise ValueError(
986
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
987
+ )
988
+ image_embs = added_cond_kwargs.get("image_embeds")
989
+ aug_emb = self.add_embedding(image_embs)
990
+ elif self.config.addition_embed_type == "image_hint":
991
+ # Kandinsky 2.2 - style
992
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
993
+ raise ValueError(
994
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
995
+ )
996
+ image_embs = added_cond_kwargs.get("image_embeds")
997
+ hint = added_cond_kwargs.get("hint")
998
+ aug_emb = self.add_embedding(image_embs, hint)
999
+ return aug_emb
1000
+
1001
+ def process_encoder_hidden_states(
1002
+ self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
1003
+ ) -> torch.Tensor:
1004
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1005
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1006
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1007
+ # Kandinsky 2.1 - style
1008
+ if "image_embeds" not in added_cond_kwargs:
1009
+ raise ValueError(
1010
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1011
+ )
1012
+
1013
+ image_embeds = added_cond_kwargs.get("image_embeds")
1014
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1015
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1016
+ # Kandinsky 2.2 - style
1017
+ if "image_embeds" not in added_cond_kwargs:
1018
+ raise ValueError(
1019
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1020
+ )
1021
+ image_embeds = added_cond_kwargs.get("image_embeds")
1022
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1023
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
1024
+ if "image_embeds" not in added_cond_kwargs:
1025
+ raise ValueError(
1026
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1027
+ )
1028
+ image_embeds = added_cond_kwargs.get("image_embeds")
1029
+ image_embeds = self.encoder_hid_proj(image_embeds)
1030
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
1031
+ return encoder_hidden_states
1032
+
1033
+ def forward(
1034
+ self,
1035
+ sample: torch.FloatTensor,
1036
+ timestep: Union[torch.Tensor, float, int],
1037
+ encoder_hidden_states: torch.Tensor,
1038
+ class_labels: Optional[torch.Tensor] = None,
1039
+ timestep_cond: Optional[torch.Tensor] = None,
1040
+ attention_mask: Optional[torch.Tensor] = None,
1041
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1042
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
1043
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1044
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1045
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1046
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1047
+ return_dict: bool = True,
1048
+ ) -> Union[UNet2DConditionOutput, Tuple]:
1049
+ r"""
1050
+ The [`UNet2DConditionModel`] forward method.
1051
+
1052
+ Args:
1053
+ sample (`torch.FloatTensor`):
1054
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
1055
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
1056
+ encoder_hidden_states (`torch.FloatTensor`):
1057
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
1058
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
1059
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
1060
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
1061
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
1062
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
1063
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
1064
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
1065
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
1066
+ negative values to the attention scores corresponding to "discard" tokens.
1067
+ cross_attention_kwargs (`dict`, *optional*):
1068
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1069
+ `self.processor` in
1070
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1071
+ added_cond_kwargs: (`dict`, *optional*):
1072
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
1073
+ are passed along to the UNet blocks.
1074
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
1075
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
1076
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
1077
+ A tensor that if specified is added to the residual of the middle unet block.
1078
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
1079
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
1080
+ encoder_attention_mask (`torch.Tensor`):
1081
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
1082
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
1083
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
1084
+ return_dict (`bool`, *optional*, defaults to `True`):
1085
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
1086
+ tuple.
1087
+
1088
+ Returns:
1089
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
1090
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
1091
+ otherwise a `tuple` is returned where the first element is the sample tensor.
1092
+ """
1093
+ # By default, samples have to be at least a multiple of the overall upsampling factor.
1094
+ # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
1095
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
1096
+ # on the fly if necessary.
1097
+ default_overall_up_factor = 2**self.num_upsamplers
1098
+
1099
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1100
+ forward_upsample_size = False
1101
+ upsample_size = None
1102
+
1103
+ for dim in sample.shape[-2:]:
1104
+ if dim % default_overall_up_factor != 0:
1105
+ # Forward upsample size to force interpolation output size.
1106
+ forward_upsample_size = True
1107
+ break
1108
+
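The loop above only needs the spatial size modulo the total upsampling factor. A self-contained illustration, assuming three upsampling stages (the real count depends on the block configuration):

num_upsamplers = 3                                # assumed for illustration
default_overall_up_factor = 2 ** num_upsamplers   # = 8

sample_height, sample_width = 288, 512
forward_upsample_size = any(
    dim % default_overall_up_factor != 0 for dim in (sample_height, sample_width)
)
print(forward_upsample_size)  # False: both 288 and 512 are multiples of 8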
1109
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
1110
+ # expects mask of shape:
1111
+ # [batch, key_tokens]
1112
+ # adds singleton query_tokens dimension:
1113
+ # [batch, 1, key_tokens]
1114
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
1115
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
1116
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
1117
+ if attention_mask is not None:
1118
+ # assume that mask is expressed as:
1119
+ # (1 = keep, 0 = discard)
1120
+ # convert mask into a bias that can be added to attention scores:
1121
+ # (keep = +0, discard = -10000.0)
1122
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1123
+ attention_mask = attention_mask.unsqueeze(1)
1124
+
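A standalone sketch of the keep/discard-to-bias conversion performed above, using a toy mask; the shapes are illustrative only.

import torch

mask = torch.tensor([[1, 1, 0]])                 # (batch, key_tokens): 1 = keep, 0 = discard
bias = (1 - mask.to(torch.float32)) * -10000.0   # kept positions become 0.0, discarded become -10000.0
bias = bias.unsqueeze(1)                         # (batch, 1, key_tokens), broadcastable over query tokens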
1125
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
1126
+ if encoder_attention_mask is not None:
1127
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
1128
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
1129
+
1130
+ # 0. center input if necessary
1131
+ if self.config.center_input_sample:
1132
+ sample = 2 * sample - 1.0
1133
+
1134
+ # 1. time
1135
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
1136
+ emb = self.time_embedding(t_emb, timestep_cond)
1137
+ aug_emb = None
1138
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
1139
+ if class_emb is not None:
1140
+ if self.config.class_embeddings_concat:
1141
+ emb = torch.cat([emb, class_emb], dim=-1)
1142
+ else:
1143
+ emb = emb + class_emb
1144
+
1145
+ aug_emb = self.get_aug_embed(
1146
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1147
+ )
1148
+ if self.config.addition_embed_type == "image_hint":
1149
+ aug_emb, hint = aug_emb
1150
+ sample = torch.cat([sample, hint], dim=1)
1151
+
1152
+ emb = emb + aug_emb if aug_emb is not None else emb
1153
+
1154
+ if self.time_embed_act is not None:
1155
+ emb = self.time_embed_act(emb)
1156
+
1157
+ encoder_hidden_states = self.process_encoder_hidden_states(
1158
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1159
+ )
1160
+
1161
+ # 2. pre-process
1162
+ sample = self.conv_in(sample)
1163
+
1164
+ # 2.5 GLIGEN position net
1165
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1166
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1167
+ gligen_args = cross_attention_kwargs.pop("gligen")
1168
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1169
+
1170
+ # 3. down
1171
+ # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
1172
+ # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
1173
+ if cross_attention_kwargs is not None:
1174
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1175
+ lora_scale = cross_attention_kwargs.pop("scale", 1.0)
1176
+ else:
1177
+ lora_scale = 1.0
1178
+
1179
+ if USE_PEFT_BACKEND:
1180
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1181
+ scale_lora_layers(self, lora_scale)
1182
+
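The `scale` entry popped above is how a caller would typically dial LoRA strength for a single forward pass; a hedged sketch of that calling convention, where `unet` and the 0.7 value are assumptions for illustration.

cross_attention_kwargs = {"scale": 0.7}  # popped here and applied via scale_lora_layers, never forwarded to the blocks
# out = unet(
#     sample, timestep, encoder_hidden_states,
#     cross_attention_kwargs=cross_attention_kwargs,
# )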
1183
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1184
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1185
+ is_adapter = down_intrablock_additional_residuals is not None
1186
+ # maintain backward compatibility for legacy usage, where
1187
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1188
+ # but can only use one or the other
1189
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1190
+ deprecate(
1191
+ "T2I should not use down_block_additional_residuals",
1192
+ "1.3.0",
1193
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1194
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1195
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1196
+ standard_warn=False,
1197
+ )
1198
+ down_intrablock_additional_residuals = down_block_additional_residuals
1199
+ is_adapter = True
1200
+
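To make the ControlNet / T2I-Adapter split above concrete, a sketch of which keyword each conditioning path is expected to use; all tensor shapes and counts are placeholders, not values from this repository.

import torch

# ControlNet-style: residuals matched one-to-one with `down_block_res_samples`, plus a mid-block residual.
controlnet_kwargs = dict(
    down_block_additional_residuals=tuple(torch.zeros(1, 320, 64, 64) for _ in range(3)),  # placeholder count/shapes
    mid_block_additional_residual=torch.zeros(1, 1280, 8, 8),                              # placeholder shape
)

# T2I-Adapter-style: intra-block residuals only, consumed one by one via pop(0) below.
adapter_kwargs = dict(
    down_intrablock_additional_residuals=[torch.zeros(1, 320, 64, 64)],  # placeholder shape
)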
1201
+ down_block_res_samples = (sample,)
1202
+ for downsample_block in self.down_blocks:
1203
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1204
+ # For t2i-adapter CrossAttnDownBlock2D
1205
+ additional_residuals = {}
1206
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1207
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1208
+
1209
+ sample, res_samples = downsample_block(
1210
+ hidden_states=sample,
1211
+ temb=emb,
1212
+ encoder_hidden_states=encoder_hidden_states,
1213
+ attention_mask=attention_mask,
1214
+ cross_attention_kwargs=cross_attention_kwargs,
1215
+ encoder_attention_mask=encoder_attention_mask,
1216
+ **additional_residuals,
1217
+ )
1218
+ else:
1219
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1220
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1221
+ sample += down_intrablock_additional_residuals.pop(0)
1222
+
1223
+ down_block_res_samples += res_samples
1224
+
1225
+ if is_controlnet:
1226
+ new_down_block_res_samples = ()
1227
+
1228
+ for down_block_res_sample, down_block_additional_residual in zip(
1229
+ down_block_res_samples, down_block_additional_residuals
1230
+ ):
1231
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1232
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1233
+
1234
+ down_block_res_samples = new_down_block_res_samples
1235
+
1236
+ # 4. mid
1237
+ if self.mid_block is not None:
1238
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1239
+ sample = self.mid_block(
1240
+ sample,
1241
+ emb,
1242
+ encoder_hidden_states=encoder_hidden_states,
1243
+ attention_mask=attention_mask,
1244
+ cross_attention_kwargs=cross_attention_kwargs,
1245
+ encoder_attention_mask=encoder_attention_mask,
1246
+ )
1247
+ else:
1248
+ sample = self.mid_block(sample, emb)
1249
+
1250
+ # To support T2I-Adapter-XL
1251
+ if (
1252
+ is_adapter
1253
+ and len(down_intrablock_additional_residuals) > 0
1254
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1255
+ ):
1256
+ sample += down_intrablock_additional_residuals.pop(0)
1257
+
1258
+ if is_controlnet:
1259
+ sample = sample + mid_block_additional_residual
1260
+
1261
+ # 5. up
1262
+ for i, upsample_block in enumerate(self.up_blocks):
1263
+ is_final_block = i == len(self.up_blocks) - 1
1264
+
1265
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1266
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1267
+
1268
+ # if we have not reached the final block and need to forward the
1269
+ # upsample size, we do it here
1270
+ if not is_final_block and forward_upsample_size:
1271
+ upsample_size = down_block_res_samples[-1].shape[2:]
1272
+
1273
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1274
+ sample = upsample_block(
1275
+ hidden_states=sample,
1276
+ temb=emb,
1277
+ res_hidden_states_tuple=res_samples,
1278
+ encoder_hidden_states=encoder_hidden_states,
1279
+ cross_attention_kwargs=cross_attention_kwargs,
1280
+ upsample_size=upsample_size,
1281
+ attention_mask=attention_mask,
1282
+ encoder_attention_mask=encoder_attention_mask,
1283
+ )
1284
+ else:
1285
+ sample = upsample_block(
1286
+ hidden_states=sample,
1287
+ temb=emb,
1288
+ res_hidden_states_tuple=res_samples,
1289
+ upsample_size=upsample_size,
1290
+ )
1291
+
1292
+ # 6. post-process
1293
+ if self.conv_norm_out:
1294
+ sample = self.conv_norm_out(sample)
1295
+ sample = self.conv_act(sample)
1296
+ sample = self.conv_out(sample)
1297
+
1298
+ if USE_PEFT_BACKEND:
1299
+ # remove `lora_scale` from each PEFT layer
1300
+ unscale_lora_layers(self, lora_scale)
1301
+
1302
+ if not return_dict:
1303
+ return (sample,)
1304
+
1305
+ return UNet2DConditionOutput(sample=sample)
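A hedged end-to-end sketch of calling the forward method above. It assumes a Stable-Diffusion-like configuration (4 latent channels, 768-dim text states); none of the literal values come from this repository, and `unet` stands in for an already-constructed instance.

import torch

sample = torch.randn(1, 4, 64, 64)                # (batch, channel, height, width) noisy latents, assumed shape
timestep = torch.tensor([10])                     # current denoising step
encoder_hidden_states = torch.randn(1, 77, 768)   # (batch, sequence_length, feature_dim), assumed shape

# out = unet(sample, timestep, encoder_hidden_states, return_dict=True)
# out.sample would have the same spatial shape as the input latents.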
src/point_network.py ADDED
@@ -0,0 +1,45 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Tuple
5
+ from diffusers.models.modeling_utils import ModelMixin
6
+ class PointNet(ModelMixin):
7
+ def __init__(
8
+ self,
9
+ conditioning_channels: int = 1,
10
+ out_channels: Tuple[int] = (320, 640, 1280, 1280),
11
+ downsamples: Tuple[int] = (6, 2, 2, 2)
12
+ ):
13
+ super(PointNet, self).__init__()
14
+
15
+ self.blocks = nn.ModuleList()
16
+ current_channels = conditioning_channels
17
+
18
+ # Construct the convolution blocks
19
+ for out_channel, downsample in zip(out_channels, downsamples):
20
+ layers = []
21
+ for _ in range(downsample // 2):
22
+ layers.append(nn.Conv2d(in_channels=current_channels, out_channels=out_channel, kernel_size=3, stride=2, padding=1))
23
+ layers.append(nn.SiLU())
24
+ current_channels = out_channel
25
+ self.blocks.append(nn.Sequential(*layers))
26
+
27
+ def forward(self, x):
28
+ embeddings = []
29
+ embedding = x
30
+ for block in self.blocks:
31
+ embedding = block(embedding)
32
+ B, C, H, W = embedding.shape
33
+ embeddings.append(embedding.view(B, C, H * W).transpose(1, 2))
34
+ # embeddings.append(embedding)
35
+ return embeddings
36
+
37
+ if __name__ == "__main__":
38
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
+ print(f'Using device: {device}')
40
+ model = PointNet().to(device)
41
+
42
+ dummy_input = torch.randn(1, 1, 288, 512).to(device) # Batch size = 1, Channels = 1, Height = 288, Width = 512
43
+ embeddings = model(dummy_input)
44
+ for i, embedding in enumerate(embeddings):
45
+ print(f"Output at layer {i + 1}:", embedding.shape)
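With the default configuration above (out_channels 320/640/1280/1280, downsamples 6/2/2/2), the `__main__` demo should report four token sequences whose lengths follow from the stride-2 convolution output formula; a CPU-only sketch added here to verify the expected shapes:

# Each stride-2 conv with kernel 3 and padding 1 maps a size to (size - 1) // 2 + 1.
def conv_out(size: int, n_convs: int) -> int:
    for _ in range(n_convs):
        size = (size - 1) // 2 + 1
    return size

h, w = 288, 512
for channels, n_convs in zip((320, 640, 1280, 1280), (3, 1, 1, 1)):  # downsamples // 2 conv layers per block
    h, w = conv_out(h, n_convs), conv_out(w, n_convs)
    print((1, h * w, channels))  # (1, 2304, 320), (1, 576, 640), (1, 144, 1280), (1, 40, 1280)

The flattened (B, H*W, C) layout matches what the forward method returns after the view and transpose.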
test_cases/hz0.png ADDED
test_cases/hz01_0.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:810f8089bdc9833596ef163a372128de1b5e9e7fce29e17779e2d0539119d10e
3
+ size 262272
test_cases/hz01_0.png ADDED
test_cases/hz01_1.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:72334ed6a7ab06e43b554cb3ce98830e012270762aa64fecfc1e7eaa49037019
3
+ size 262272
test_cases/hz01_1.png ADDED
test_cases/hz1.png ADDED
test_cases/more_cases/az0.png ADDED
test_cases/more_cases/az1.JPG ADDED
test_cases/more_cases/hi0.png ADDED
test_cases/more_cases/hi1.jpg ADDED
test_cases/more_cases/hz0_lineart.png ADDED
test_cases/more_cases/kn0.jpg ADDED
test_cases/more_cases/kn1.jpg ADDED
test_cases/more_cases/rk0.jpg ADDED
test_cases/more_cases/rk1.jpg ADDED
utils/image_util.py ADDED
@@ -0,0 +1,36 @@
1
+ import matplotlib
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+
6
+ def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
7
+ """
8
+ Resize image to limit maximum edge length while keeping aspect ratio.
9
+ Args:
10
+ img (`Image.Image`):
11
+ Image to be resized.
12
+ max_edge_resolution (`int`):
13
+ Maximum edge length (pixel).
14
+ Returns:
15
+ `Image.Image`: Resized image.
16
+ """
17
+
18
+ original_width, original_height = img.size
19
+
20
+ downscale_factor = min(
21
+ max_edge_resolution / original_width, max_edge_resolution / original_height
22
+ )
23
+
24
+ new_width = int(original_width * downscale_factor)
25
+ new_height = int(original_height * downscale_factor)
26
+
27
+ resized_img = img.resize((new_width, new_height))
28
+ return resized_img
29
+
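A quick usage sketch for `resize_max_res`; the input size below is made up for illustration.

from PIL import Image  # same import as at the top of this module

example = Image.new("RGB", (1024, 768))
resized = resize_max_res(example, max_edge_resolution=512)
print(resized.size)  # (512, 384): the longest edge is capped and the aspect ratio preserved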
30
+ def chw2hwc(chw):
31
+ assert 3 == len(chw.shape)
32
+ if isinstance(chw, torch.Tensor):
33
+ hwc = torch.permute(chw, (1, 2, 0))
34
+ elif isinstance(chw, np.ndarray):
35
+ hwc = np.moveaxis(chw, 0, -1)
36
+ return hwc
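A quick sketch of the intended axis reordering for both supported input types; the shapes are made up, and it assumes the helper hands back the converted array (inputs other than `torch.Tensor` or `np.ndarray` fall through the branches above).

import numpy as np
import torch

print(chw2hwc(np.zeros((3, 288, 512))).shape)   # (288, 512, 3)
print(chw2hwc(torch.zeros(3, 288, 512)).shape)  # torch.Size([288, 512, 3])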