salso committed on
Commit 5432315 · verified · 1 Parent(s): 80cf68b

Upload 7 files

flux/__init__.py ADDED
File without changes
flux/block.py ADDED
@@ -0,0 +1,461 @@
1
+ # Recycled from OminiControl and modified to accept an extra condition.
2
+ # While Zenctrl pursued a similar idea, it diverged structurally.
3
+ # We appreciate the clarity of Omini's implementation and decided to align with it.
4
+
5
+ import torch
6
+ from typing import List, Union, Optional, Dict, Any, Callable
7
+ from diffusers.models.attention_processor import Attention, F
8
+ from .lora_controller import enable_lora
9
+ from diffusers.models.embeddings import apply_rotary_emb
10
+
11
+ def attn_forward(
12
+ attn: Attention,
13
+ hidden_states: torch.FloatTensor,
14
+ encoder_hidden_states: torch.FloatTensor = None,
15
+ condition_latents: torch.FloatTensor = None,
16
+ extra_condition_latents: torch.FloatTensor = None,
17
+ attention_mask: Optional[torch.FloatTensor] = None,
18
+ image_rotary_emb: Optional[torch.Tensor] = None,
19
+ cond_rotary_emb: Optional[torch.Tensor] = None,
20
+ extra_cond_rotary_emb: Optional[torch.Tensor] = None,
21
+ model_config: Optional[Dict[str, Any]] = {},
22
+ ) -> torch.FloatTensor:
23
+ batch_size, _, _ = (
24
+ hidden_states.shape
25
+ if encoder_hidden_states is None
26
+ else encoder_hidden_states.shape
27
+ )
28
+
29
+ with enable_lora(
30
+ (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
31
+ ):
32
+ # `sample` projections.
33
+ query = attn.to_q(hidden_states)
34
+ key = attn.to_k(hidden_states)
35
+ value = attn.to_v(hidden_states)
36
+
37
+ inner_dim = key.shape[-1]
38
+ head_dim = inner_dim // attn.heads
39
+
40
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
41
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
42
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
43
+
44
+ if attn.norm_q is not None:
45
+ query = attn.norm_q(query)
46
+ if attn.norm_k is not None:
47
+ key = attn.norm_k(key)
48
+
49
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
50
+ if encoder_hidden_states is not None:
51
+ # `context` projections.
52
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
53
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
54
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
55
+
56
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
57
+ batch_size, -1, attn.heads, head_dim
58
+ ).transpose(1, 2)
59
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
60
+ batch_size, -1, attn.heads, head_dim
61
+ ).transpose(1, 2)
62
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
63
+ batch_size, -1, attn.heads, head_dim
64
+ ).transpose(1, 2)
65
+
66
+ if attn.norm_added_q is not None:
67
+ encoder_hidden_states_query_proj = attn.norm_added_q(
68
+ encoder_hidden_states_query_proj
69
+ )
70
+ if attn.norm_added_k is not None:
71
+ encoder_hidden_states_key_proj = attn.norm_added_k(
72
+ encoder_hidden_states_key_proj
73
+ )
74
+
75
+ # attention
76
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
77
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
78
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
79
+
80
+ if image_rotary_emb is not None:
81
+
82
+
83
+ query = apply_rotary_emb(query, image_rotary_emb)
84
+ key = apply_rotary_emb(key, image_rotary_emb)
85
+
86
+ if condition_latents is not None:
87
+ cond_query = attn.to_q(condition_latents)
88
+ cond_key = attn.to_k(condition_latents)
89
+ cond_value = attn.to_v(condition_latents)
90
+
91
+ cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
92
+ 1, 2
93
+ )
94
+ cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
95
+ cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
96
+ 1, 2
97
+ )
98
+ if attn.norm_q is not None:
99
+ cond_query = attn.norm_q(cond_query)
100
+ if attn.norm_k is not None:
101
+ cond_key = attn.norm_k(cond_key)
102
+
103
+ #extra condition
104
+ if extra_condition_latents is not None:
105
+ extra_cond_query = attn.to_q(extra_condition_latents)
106
+ extra_cond_key = attn.to_k(extra_condition_latents)
107
+ extra_cond_value = attn.to_v(extra_condition_latents)
108
+
109
+ extra_cond_query = extra_cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
110
+ 1, 2
111
+ )
112
+ extra_cond_key = extra_cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
113
+ extra_cond_value = extra_cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
114
+ 1, 2
115
+ )
116
+ if attn.norm_q is not None:
117
+ extra_cond_query = attn.norm_q(extra_cond_query)
118
+ if attn.norm_k is not None:
119
+ extra_cond_key = attn.norm_k(extra_cond_key)
120
+
121
+
122
+ if extra_cond_rotary_emb is not None:
123
+ extra_cond_query = apply_rotary_emb(extra_cond_query, extra_cond_rotary_emb)
124
+ extra_cond_key = apply_rotary_emb(extra_cond_key, extra_cond_rotary_emb)
125
+
126
+ if cond_rotary_emb is not None:
127
+ cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
128
+ cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)
129
+
130
+ if condition_latents is not None:
131
+ if extra_condition_latents is not None:
132
+
133
+ query = torch.cat([query, cond_query, extra_cond_query], dim=2)
134
+ key = torch.cat([key, cond_key, extra_cond_key], dim=2)
135
+ value = torch.cat([value, cond_value, extra_cond_value], dim=2)
136
+ else:
137
+ query = torch.cat([query, cond_query], dim=2)
138
+ key = torch.cat([key, cond_key], dim=2)
139
+ value = torch.cat([value, cond_value], dim=2)
140
+ # print("concat Omini latents: ", query.shape, key.shape, value.shape)
141
+
142
+
143
+ if not model_config.get("union_cond_attn", True):
144
+
145
+ attention_mask = torch.ones(
146
+ query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
147
+ )
148
+ condition_n = cond_query.shape[2]
149
+ attention_mask[-condition_n:, :-condition_n] = False
150
+ attention_mask[:-condition_n, -condition_n:] = False
151
+ elif model_config.get("independent_condition", False):
152
+ attention_mask = torch.ones(
153
+ query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
154
+ )
155
+ condition_n = cond_query.shape[2]
156
+ attention_mask[-condition_n:, :-condition_n] = False
157
+
158
+ if hasattr(attn, "c_factor"):
159
+ attention_mask = torch.zeros(
160
+ query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
161
+ )
162
+ condition_n = cond_query.shape[2]
163
+ condition_e = extra_cond_query.shape[2]
164
+ bias = torch.log(attn.c_factor[0])
165
+ attention_mask[-condition_n-condition_e:-condition_e, :-condition_n-condition_e] = bias
166
+ attention_mask[:-condition_n-condition_e, -condition_n-condition_e:-condition_e] = bias
167
+
168
+ bias = torch.log(attn.c_factor[1])
169
+ attention_mask[-condition_e:, :-condition_n-condition_e] = bias
170
+ attention_mask[:-condition_n-condition_e, -condition_e:] = bias
171
+
172
+ hidden_states = F.scaled_dot_product_attention(
173
+ query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
174
+ )
175
+ hidden_states = hidden_states.transpose(1, 2).reshape(
176
+ batch_size, -1, attn.heads * head_dim
177
+ )
178
+ hidden_states = hidden_states.to(query.dtype)
179
+
180
+ if encoder_hidden_states is not None:
181
+ if condition_latents is not None:
182
+ if extra_condition_latents is not None:
183
+ encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents = (
184
+ hidden_states[:, : encoder_hidden_states.shape[1]],
185
+ hidden_states[
186
+ :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]*2
187
+ ],
188
+ hidden_states[:, -condition_latents.shape[1]*2 :-condition_latents.shape[1]],
189
+ hidden_states[:, -condition_latents.shape[1] :], #extra condition latents
190
+ )
191
+ else:
192
+ encoder_hidden_states, hidden_states, condition_latents = (
193
+ hidden_states[:, : encoder_hidden_states.shape[1]],
194
+ hidden_states[
195
+ :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]
196
+ ],
197
+ hidden_states[:, -condition_latents.shape[1] :]
198
+ )
199
+ else:
200
+ encoder_hidden_states, hidden_states = (
201
+ hidden_states[:, : encoder_hidden_states.shape[1]],
202
+ hidden_states[:, encoder_hidden_states.shape[1] :],
203
+ )
204
+
205
+ with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
206
+ # linear proj
207
+ hidden_states = attn.to_out[0](hidden_states)
208
+ # dropout
209
+ hidden_states = attn.to_out[1](hidden_states)
210
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
211
+
212
+ if condition_latents is not None:
213
+ condition_latents = attn.to_out[0](condition_latents)
214
+ condition_latents = attn.to_out[1](condition_latents)
215
+
216
+ if extra_condition_latents is not None:
217
+ extra_condition_latents = attn.to_out[0](extra_condition_latents)
218
+ extra_condition_latents = attn.to_out[1](extra_condition_latents)
219
+
220
+
221
+ return (
222
+ # (hidden_states, encoder_hidden_states, condition_latents, extra_condition_latents)
223
+ (hidden_states, encoder_hidden_states, condition_latents, extra_condition_latents)
224
+ if condition_latents is not None
225
+ else (hidden_states, encoder_hidden_states)
226
+ )
227
+ elif condition_latents is not None:
228
+ # if there are condition_latents, we need to separate the hidden_states and the condition_latents
229
+ if extra_condition_latents is not None:
230
+ hidden_states, condition_latents, extra_condition_latents = (
231
+ hidden_states[:, : -condition_latents.shape[1]*2],
232
+ hidden_states[:, -condition_latents.shape[1]*2 :-condition_latents.shape[1]],
233
+ hidden_states[:, -condition_latents.shape[1] :],
234
+ )
235
+ else:
236
+ hidden_states, condition_latents = (
237
+ hidden_states[:, : -condition_latents.shape[1]],
238
+ hidden_states[:, -condition_latents.shape[1] :],
239
+ )
240
+ return hidden_states, condition_latents, extra_condition_latents
241
+ else:
242
+ return hidden_states
243
+
244
+
245
+ def block_forward(
246
+ self,
247
+ hidden_states: torch.FloatTensor,
248
+ encoder_hidden_states: torch.FloatTensor,
249
+ condition_latents: torch.FloatTensor,
250
+ extra_condition_latents: torch.FloatTensor,
251
+ temb: torch.FloatTensor,
252
+ cond_temb: torch.FloatTensor,
253
+ extra_cond_temb: torch.FloatTensor,
254
+ cond_rotary_emb=None,
255
+ extra_cond_rotary_emb=None,
256
+ image_rotary_emb=None,
257
+ model_config: Optional[Dict[str, Any]] = {},
258
+ ):
259
+ use_cond = condition_latents is not None
260
+
261
+ use_extra_cond = extra_condition_latents is not None
262
+ with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
263
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
264
+ hidden_states, emb=temb
265
+ )
266
+
267
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
268
+ self.norm1_context(encoder_hidden_states, emb=temb)
269
+ )
270
+
271
+ if use_cond:
272
+ (
273
+ norm_condition_latents,
274
+ cond_gate_msa,
275
+ cond_shift_mlp,
276
+ cond_scale_mlp,
277
+ cond_gate_mlp,
278
+ ) = self.norm1(condition_latents, emb=cond_temb)
279
+ (
280
+ norm_extra_condition_latents,
281
+ extra_cond_gate_msa,
282
+ extra_cond_shift_mlp,
283
+ extra_cond_scale_mlp,
284
+ extra_cond_gate_mlp,
285
+ ) = self.norm1(extra_condition_latents, emb=extra_cond_temb)
286
+
287
+ # Attention.
288
+ result = attn_forward(
289
+ self.attn,
290
+ model_config=model_config,
291
+ hidden_states=norm_hidden_states,
292
+ encoder_hidden_states=norm_encoder_hidden_states,
293
+ condition_latents=norm_condition_latents if use_cond else None,
294
+ extra_condition_latents=norm_extra_condition_latents if use_cond else None,
295
+ image_rotary_emb=image_rotary_emb,
296
+ cond_rotary_emb=cond_rotary_emb if use_cond else None,
297
+ extra_cond_rotary_emb=extra_cond_rotary_emb if use_extra_cond else None,
298
+ )
299
+ # print("in self block: ", result.shape)
300
+ attn_output, context_attn_output = result[:2]
301
+ cond_attn_output = result[2] if use_cond else None
302
+ extra_condition_output = result[3]
303
+
304
+ # Process attention outputs for the `hidden_states`.
305
+ # 1. hidden_states
306
+ attn_output = gate_msa.unsqueeze(1) * attn_output
307
+ hidden_states = hidden_states + attn_output
308
+ # 2. encoder_hidden_states
309
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
310
+
311
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
312
+ # 3. condition_latents
313
+ if use_cond:
314
+ cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
315
+ condition_latents = condition_latents + cond_attn_output
316
+ # 4. extra_condition_latents: gate the extra-condition attention output and add it
317
+ if use_extra_cond:
318
+ extra_condition_output = extra_cond_gate_msa.unsqueeze(1) * extra_condition_output
319
+ extra_condition_latents = extra_condition_latents + extra_condition_output
320
+
321
+ if model_config.get("add_cond_attn", False):
322
+ hidden_states += cond_attn_output
323
+ hidden_states += extra_condition_output
324
+
325
+
326
+ # LayerNorm + MLP.
327
+ # 1. hidden_states
328
+ norm_hidden_states = self.norm2(hidden_states)
329
+ norm_hidden_states = (
330
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
331
+ )
332
+ # 2. encoder_hidden_states
333
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
334
+ norm_encoder_hidden_states = (
335
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
336
+ )
337
+ # 3. condition_latents
338
+ if use_cond:
339
+ norm_condition_latents = self.norm2(condition_latents)
340
+ norm_condition_latents = (
341
+ norm_condition_latents * (1 + cond_scale_mlp[:, None])
342
+ + cond_shift_mlp[:, None]
343
+ )
344
+
345
+ if use_extra_cond:
346
+ # 4. extra_condition_latents
347
+ extra_norm_condition_latents = self.norm2(extra_condition_latents)
348
+ extra_norm_condition_latents = (
349
+ extra_norm_condition_latents * (1 + extra_cond_scale_mlp[:, None])
350
+ + extra_cond_shift_mlp[:, None]
351
+ )
352
+
353
+ # Feed-forward.
354
+ with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
355
+ # 1. hidden_states
356
+ ff_output = self.ff(norm_hidden_states)
357
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
358
+ # 2. encoder_hidden_states
359
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
360
+ context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
361
+ # 3. condition_latents
362
+ if use_cond:
363
+ cond_ff_output = self.ff(norm_condition_latents)
364
+ cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
365
+
366
+ if use_extra_cond:
367
+ extra_cond_ff_output = self.ff(extra_norm_condition_latents)
368
+ extra_cond_ff_output = extra_cond_gate_mlp.unsqueeze(1) * extra_cond_ff_output
369
+
370
+ # Process feed-forward outputs.
371
+ hidden_states = hidden_states + ff_output
372
+ encoder_hidden_states = encoder_hidden_states + context_ff_output
373
+ if use_cond:
374
+ condition_latents = condition_latents + cond_ff_output
375
+ if use_extra_cond:
376
+ extra_condition_latents = extra_condition_latents + extra_cond_ff_output
377
+
378
+ # Clip to avoid overflow.
379
+ if encoder_hidden_states.dtype == torch.float16:
380
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
381
+
382
+ return encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents if use_cond else None
383
+
384
+
385
+ def single_block_forward(
386
+ self,
387
+ hidden_states: torch.FloatTensor,
388
+ temb: torch.FloatTensor,
389
+ image_rotary_emb=None,
390
+ condition_latents: torch.FloatTensor = None,
391
+ extra_condition_latents: torch.FloatTensor = None,
392
+ cond_temb: torch.FloatTensor = None,
393
+ extra_cond_temb: torch.FloatTensor = None,
394
+ cond_rotary_emb=None,
395
+ extra_cond_rotary_emb=None,
396
+ model_config: Optional[Dict[str, Any]] = {},
397
+ ):
398
+
399
+ using_cond = condition_latents is not None
400
+ using_extra_cond = extra_condition_latents is not None
401
+ residual = hidden_states
402
+ with enable_lora(
403
+ (
404
+ self.norm.linear,
405
+ self.proj_mlp,
406
+ ),
407
+ model_config.get("latent_lora", False),
408
+ ):
409
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
410
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
411
+ if using_cond:
412
+ residual_cond = condition_latents
413
+ norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
414
+ mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))
415
+
416
+ if using_extra_cond:
417
+ extra_residual_cond = extra_condition_latents
418
+ extra_norm_condition_latents, extra_cond_gate = self.norm(extra_condition_latents, emb=extra_cond_temb)
419
+ extra_mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(extra_norm_condition_latents))
420
+
421
+ attn_output = attn_forward(
422
+ self.attn,
423
+ model_config=model_config,
424
+ hidden_states=norm_hidden_states,
425
+ image_rotary_emb=image_rotary_emb,
426
+ **(
427
+ {
428
+ "condition_latents": norm_condition_latents,
429
+ "cond_rotary_emb": cond_rotary_emb if using_cond else None,
430
+ "extra_condition_latents": extra_norm_condition_latents if using_cond else None,
431
+ "extra_cond_rotary_emb": extra_cond_rotary_emb if using_cond else None,
432
+ }
433
+ if using_cond
434
+ else {}
435
+ ),
436
+ )
437
+
438
+ if using_cond:
439
+ attn_output, cond_attn_output, extra_cond_attn_output = attn_output
440
+
441
+
442
+ with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
443
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
444
+ gate = gate.unsqueeze(1)
445
+ hidden_states = gate * self.proj_out(hidden_states)
446
+ hidden_states = residual + hidden_states
447
+ if using_cond:
448
+ condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
449
+ cond_gate = cond_gate.unsqueeze(1)
450
+ condition_latents = cond_gate * self.proj_out(condition_latents)
451
+ condition_latents = residual_cond + condition_latents
452
+
453
+ extra_condition_latents = torch.cat([extra_cond_attn_output, extra_mlp_cond_hidden_states], dim=2)
454
+ extra_cond_gate = extra_cond_gate.unsqueeze(1)
455
+ extra_condition_latents = extra_cond_gate * self.proj_out(extra_condition_latents)
456
+ extra_condition_latents = extra_residual_cond + extra_condition_latents
457
+
458
+ if hidden_states.dtype == torch.float16:
459
+ hidden_states = hidden_states.clip(-65504, 65504)
460
+
461
+ return hidden_states if not using_cond else (hidden_states, condition_latents, extra_condition_latents)
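A quick numeric aside (illustrative only, not part of the upload): the `c_factor` branch in `attn_forward` above adds `log(c_factor)` to the image-to-condition attention logits. Adding log(c) before the softmax multiplies that entry's unnormalised weight by c, so `c_factor[0]` and `c_factor[1]` scale how strongly the image tokens attend to the first and second condition streams respectively.

import torch

logits = torch.tensor([1.0, 2.0, 3.0])
c = torch.tensor(2.0)

biased = logits.clone()
biased[2] = biased[2] + torch.log(c)  # the same bias trick used with attn.c_factor

w = torch.softmax(logits, dim=0)
w_biased = torch.softmax(biased, dim=0)

# Relative to entry 0, entry 2's attention weight grows by exactly c.
print((w_biased[2] / w_biased[0]) / (w[2] / w[0]))  # tensor(2.0000)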
flux/condition.py ADDED
@@ -0,0 +1,89 @@
1
+ # Recycled from OminiControl and modified to accept an extra condition.
2
+ # While Zenctrl pursued a similar idea, it diverged structurally.
3
+ # We appreciate the clarity of Omini's implementation and decided to align with it.
4
+
5
+ import torch
6
+ from typing import Optional, Union, List, Tuple
7
+ from diffusers.pipelines import FluxPipeline
8
+ from PIL import Image, ImageFilter
9
+ import numpy as np
10
+ import cv2
11
+
12
+ # from pipeline_tools import encode_images
13
+ from .pipeline_tools import encode_images
14
+
15
+ condition_dict = {
16
+ "subject": 1,
17
+ "sr": 2,
18
+ "cot": 3,
19
+ }
20
+
21
+
22
+ class Condition(object):
23
+ def __init__(
24
+ self,
25
+ condition_type: str,
26
+ raw_img: Union[Image.Image, torch.Tensor] = None,
27
+ condition: Union[Image.Image, torch.Tensor] = None,
28
+ position_delta=None,
29
+ ) -> None:
30
+ self.condition_type = condition_type
31
+ assert raw_img is not None or condition is not None
32
+ if raw_img is not None:
33
+ self.condition = self.get_condition(condition_type, raw_img)
34
+ else:
35
+ self.condition = condition
36
+ self.position_delta = position_delta
37
+
38
+
39
+ def get_condition(
40
+ self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
41
+ ) -> Union[Image.Image, torch.Tensor]:
42
+ """
43
+ Returns the condition image.
44
+ """
45
+ if condition_type == "subject":
46
+ return raw_img
47
+ elif condition_type == "sr":
48
+ return raw_img
49
+ elif condition_type == "cot":
50
+ return raw_img.convert("RGB")
51
+ return self.condition
52
+
53
+
54
+ @property
55
+ def type_id(self) -> int:
56
+ """
57
+ Returns the type id of the condition.
58
+ """
59
+ return condition_dict[self.condition_type]
60
+
61
+ def encode(
62
+ self, pipe: FluxPipeline, empty: bool = False
63
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
64
+ """
65
+ Encodes the condition into tokens, ids and type_id.
66
+ """
67
+ if self.condition_type in [
68
+ "subject",
69
+ "sr",
70
+ "cot"
71
+ ]:
72
+ if empty:
73
+ # make the condition black
74
+ e_condition = Image.new("RGB", self.condition.size, (0, 0, 0))
75
+ e_condition = e_condition.convert("RGB")
76
+ tokens, ids = encode_images(pipe, e_condition)
77
+ else:
78
+ tokens, ids = encode_images(pipe, self.condition)
79
+ else:
80
+ raise NotImplementedError(
81
+ f"Condition type {self.condition_type} not implemented"
82
+ )
83
+ if self.position_delta is None and self.condition_type == "subject":
84
+ self.position_delta = [0, -self.condition.size[0] // 16]
85
+ if self.position_delta is not None:
86
+ ids[:, 1] += self.position_delta[0]
87
+ ids[:, 2] += self.position_delta[1]
88
+ type_id = torch.ones_like(ids[:, :1]) * self.type_id
89
+ return tokens, ids, type_id
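A minimal usage sketch for `Condition` (illustrative; `pipe` is assumed to be an already-loaded `FluxPipeline` and the file name is a placeholder):

from PIL import Image
from flux.condition import Condition

subject = Condition("subject", raw_img=Image.open("subject.png").convert("RGB"))
tokens, ids, type_id = subject.encode(pipe)
# tokens: packed VAE latents [1, token_n, 64]; ids: [token_n, 3] position ids,
# shifted left by width // 16 via the default "subject" position_delta;
# type_id: [token_n, 1] filled with condition_dict["subject"] == 1.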
flux/generate.py ADDED
@@ -0,0 +1,337 @@
1
+ # Recycled from OminiControl and modified to accept an extra condition.
2
+ # While Zenctrl pursued a similar idea, it diverged structurally.
3
+ # We appreciate the clarity of Omini's implementation and decided to align with it.
4
+
5
+ import torch
6
+ import yaml, os
7
+ from diffusers.pipelines import FluxPipeline
8
+ from typing import List, Union, Optional, Dict, Any, Callable
9
+ from .transformer import tranformer_forward
10
+ from .condition import Condition
11
+
12
+
13
+ from diffusers.pipelines.flux.pipeline_flux import (
14
+ FluxPipelineOutput,
15
+ calculate_shift,
16
+ retrieve_timesteps,
17
+ np,
18
+ )
19
+
20
+
21
+ def get_config(config_path: str = None):
22
+ config_path = config_path or os.environ.get("XFL_CONFIG")
23
+ if not config_path:
24
+ return {}
25
+ with open(config_path, "r") as f:
26
+ config = yaml.safe_load(f)
27
+ return config
28
+
29
+
30
+ def prepare_params(
31
+ prompt: Union[str, List[str]] = None,
32
+ prompt_2: Optional[Union[str, List[str]]] = None,
33
+ height: Optional[int] = 512,
34
+ width: Optional[int] = 512,
35
+ num_inference_steps: int = 28,
36
+ timesteps: List[int] = None,
37
+ guidance_scale: float = 3.5,
38
+ num_images_per_prompt: Optional[int] = 1,
39
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
40
+ latents: Optional[torch.FloatTensor] = None,
41
+ prompt_embeds: Optional[torch.FloatTensor] = None,
42
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
43
+ output_type: Optional[str] = "pil",
44
+ return_dict: bool = True,
45
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
46
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
47
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
48
+ max_sequence_length: int = 512,
49
+ **kwargs: dict,
50
+ ):
51
+ return (
52
+ prompt,
53
+ prompt_2,
54
+ height,
55
+ width,
56
+ num_inference_steps,
57
+ timesteps,
58
+ guidance_scale,
59
+ num_images_per_prompt,
60
+ generator,
61
+ latents,
62
+ prompt_embeds,
63
+ pooled_prompt_embeds,
64
+ output_type,
65
+ return_dict,
66
+ joint_attention_kwargs,
67
+ callback_on_step_end,
68
+ callback_on_step_end_tensor_inputs,
69
+ max_sequence_length,
70
+ )
71
+
72
+
73
+ def seed_everything(seed: int = 42):
74
+ torch.backends.cudnn.deterministic = True
75
+ torch.manual_seed(seed)
76
+ np.random.seed(seed)
77
+
78
+
79
+ @torch.no_grad()
80
+ def generate(
81
+ pipeline: FluxPipeline,
82
+ conditions: List[Condition] = None,
83
+ config_path: str = None,
84
+ model_config: Optional[Dict[str, Any]] = {},
85
+ condition_scale: float = [1, 1],
86
+ default_lora: bool = False,
87
+ image_guidance_scale: float = 1.0,
88
+ **params: dict,
89
+ ):
90
+ model_config = model_config or get_config(config_path).get("model", {})
91
+ if condition_scale != [1,1]:
92
+ for name, module in pipeline.transformer.named_modules():
93
+ if not name.endswith(".attn"):
94
+ continue
95
+ module.c_factor = torch.tensor(condition_scale)
96
+
97
+ self = pipeline
98
+ (
99
+ prompt,
100
+ prompt_2,
101
+ height,
102
+ width,
103
+ num_inference_steps,
104
+ timesteps,
105
+ guidance_scale,
106
+ num_images_per_prompt,
107
+ generator,
108
+ latents,
109
+ prompt_embeds,
110
+ pooled_prompt_embeds,
111
+ output_type,
112
+ return_dict,
113
+ joint_attention_kwargs,
114
+ callback_on_step_end,
115
+ callback_on_step_end_tensor_inputs,
116
+ max_sequence_length,
117
+ ) = prepare_params(**params)
118
+
119
+ height = height or self.default_sample_size * self.vae_scale_factor
120
+ width = width or self.default_sample_size * self.vae_scale_factor
121
+
122
+ # 1. Check inputs. Raise error if not correct
123
+ self.check_inputs(
124
+ prompt,
125
+ prompt_2,
126
+ height,
127
+ width,
128
+ prompt_embeds=prompt_embeds,
129
+ pooled_prompt_embeds=pooled_prompt_embeds,
130
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
131
+ max_sequence_length=max_sequence_length,
132
+ )
133
+
134
+ self._guidance_scale = guidance_scale
135
+ self._joint_attention_kwargs = joint_attention_kwargs
136
+ self._interrupt = False
137
+
138
+ # 2. Define call parameters
139
+ if prompt is not None and isinstance(prompt, str):
140
+ batch_size = 1
141
+ elif prompt is not None and isinstance(prompt, list):
142
+ batch_size = len(prompt)
143
+ else:
144
+ batch_size = prompt_embeds.shape[0]
145
+
146
+ device = self._execution_device
147
+
148
+ lora_scale = (
149
+ self.joint_attention_kwargs.get("scale", None)
150
+ if self.joint_attention_kwargs is not None
151
+ else None
152
+ )
153
+ (
154
+ prompt_embeds,
155
+ pooled_prompt_embeds,
156
+ text_ids,
157
+ ) = self.encode_prompt(
158
+ prompt=prompt,
159
+ prompt_2=prompt_2,
160
+ prompt_embeds=prompt_embeds,
161
+ pooled_prompt_embeds=pooled_prompt_embeds,
162
+ device=device,
163
+ num_images_per_prompt=num_images_per_prompt,
164
+ max_sequence_length=max_sequence_length,
165
+ lora_scale=lora_scale,
166
+ )
167
+
168
+ # 4. Prepare latent variables
169
+ num_channels_latents = self.transformer.config.in_channels // 4
170
+ latents, latent_image_ids = self.prepare_latents(
171
+ batch_size * num_images_per_prompt,
172
+ num_channels_latents,
173
+ height,
174
+ width,
175
+ prompt_embeds.dtype,
176
+ device,
177
+ generator,
178
+ latents,
179
+ )
180
+
181
+ # 4.1. Prepare conditions
182
+ condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3))
183
+ extra_condition_latents, extra_condition_ids, extra_condition_type_ids = ([] for _ in range(3))
184
+ use_condition = conditions is not None or []
185
+ if use_condition:
186
+ if not default_lora:
187
+ pipeline.set_adapters(conditions[1].condition_type)
188
+ # for condition in conditions:
189
+ tokens, ids, type_id = conditions[0].encode(self)
190
+ condition_latents.append(tokens) # [batch_size, token_n, token_dim]
191
+ condition_ids.append(ids) # [token_n, id_dim(3)]
192
+ condition_type_ids.append(type_id) # [token_n, 1]
193
+ condition_latents = torch.cat(condition_latents, dim=1)
194
+ condition_ids = torch.cat(condition_ids, dim=0)
195
+ condition_type_ids = torch.cat(condition_type_ids, dim=0)
196
+
197
+ tokens, ids, type_id = conditions[1].encode(self)
198
+ extra_condition_latents.append(tokens) # [batch_size, token_n, token_dim]
199
+ extra_condition_ids.append(ids) # [token_n, id_dim(3)]
200
+ extra_condition_type_ids.append(type_id) # [token_n, 1]
201
+ extra_condition_latents = torch.cat(extra_condition_latents, dim=1)
202
+ extra_condition_ids = torch.cat(extra_condition_ids, dim=0)
203
+ extra_condition_type_ids = torch.cat(extra_condition_type_ids, dim=0)
204
+
205
+ # 5. Prepare timesteps
206
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
207
+ image_seq_len = latents.shape[1]
208
+ mu = calculate_shift(
209
+ image_seq_len,
210
+ self.scheduler.config.base_image_seq_len,
211
+ self.scheduler.config.max_image_seq_len,
212
+ self.scheduler.config.base_shift,
213
+ self.scheduler.config.max_shift,
214
+ )
215
+ timesteps, num_inference_steps = retrieve_timesteps(
216
+ self.scheduler,
217
+ num_inference_steps,
218
+ device,
219
+ timesteps,
220
+ sigmas,
221
+ mu=mu,
222
+ )
223
+ num_warmup_steps = max(
224
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
225
+ )
226
+ self._num_timesteps = len(timesteps)
227
+
228
+ # 6. Denoising loop
229
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
230
+ for i, t in enumerate(timesteps):
231
+ if self.interrupt:
232
+ continue
233
+
234
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
235
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
236
+
237
+ # handle guidance
238
+ if self.transformer.config.guidance_embeds:
239
+ guidance = torch.tensor([guidance_scale], device=device)
240
+ guidance = guidance.expand(latents.shape[0])
241
+ else:
242
+ guidance = None
243
+ noise_pred = tranformer_forward(
244
+ self.transformer,
245
+ model_config=model_config,
246
+ # Inputs of the condition (new feature)
247
+ condition_latents=condition_latents if use_condition else None,
248
+ condition_ids=condition_ids if use_condition else None,
249
+ condition_type_ids=condition_type_ids if use_condition else None,
250
+ extra_condition_latents=extra_condition_latents if use_condition else None,
251
+ extra_condition_ids=extra_condition_ids if use_condition else None,
252
+ extra_condition_type_ids=extra_condition_type_ids if use_condition else None,
253
+ # Inputs to the original transformer
254
+ hidden_states=latents,
255
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
256
+ timestep=timestep / 1000,
257
+ guidance=guidance,
258
+ pooled_projections=pooled_prompt_embeds,
259
+ encoder_hidden_states=prompt_embeds,
260
+ txt_ids=text_ids,
261
+ img_ids=latent_image_ids,
262
+ joint_attention_kwargs=self.joint_attention_kwargs,
263
+ return_dict=False,
264
+ )[0]
265
+
266
+ if image_guidance_scale != 1.0:
267
+ uncondition_latents = conditions[0].encode(self, empty=True)[0]
268
+ unc_pred = tranformer_forward(
269
+ self.transformer,
270
+ model_config=model_config,
271
+ # Inputs of the condition (new feature)
272
+ condition_latents=uncondition_latents if use_condition else None,
273
+ condition_ids=condition_ids if use_condition else None,
274
+ condition_type_ids=condition_type_ids if use_condition else None,
275
+ # Inputs to the original transformer
276
+ hidden_states=latents,
277
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
278
+ timestep=timestep / 1000,
279
+ guidance=torch.ones_like(guidance),
280
+ pooled_projections=pooled_prompt_embeds,
281
+ encoder_hidden_states=prompt_embeds,
282
+ txt_ids=text_ids,
283
+ img_ids=latent_image_ids,
284
+ joint_attention_kwargs=self.joint_attention_kwargs,
285
+ return_dict=False,
286
+ )[0]
287
+
288
+ noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred)
289
+
290
+ # compute the previous noisy sample x_t -> x_t-1
291
+ latents_dtype = latents.dtype
292
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
293
+
294
+ if latents.dtype != latents_dtype:
295
+ if torch.backends.mps.is_available():
296
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
297
+ latents = latents.to(latents_dtype)
298
+
299
+ if callback_on_step_end is not None:
300
+ callback_kwargs = {}
301
+ for k in callback_on_step_end_tensor_inputs:
302
+ callback_kwargs[k] = locals()[k]
303
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
304
+
305
+ latents = callback_outputs.pop("latents", latents)
306
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
307
+
308
+ # call the callback, if provided
309
+ if i == len(timesteps) - 1 or (
310
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
311
+ ):
312
+ progress_bar.update()
313
+
314
+ if output_type == "latent":
315
+ image = latents
316
+
317
+ else:
318
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
319
+ latents = (
320
+ latents / self.vae.config.scaling_factor
321
+ ) + self.vae.config.shift_factor
322
+ image = self.vae.decode(latents, return_dict=False)[0]
323
+ image = self.image_processor.postprocess(image, output_type=output_type)
324
+
325
+ # Offload all models
326
+ self.maybe_free_model_hooks()
327
+
328
+ if condition_scale != [1,1]:
329
+ for name, module in pipeline.transformer.named_modules():
330
+ if not name.endswith(".attn"):
331
+ continue
332
+ del module.c_factor
333
+
334
+ if not return_dict:
335
+ return (image,)
336
+
337
+ return FluxPipelineOutput(images=image)
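An end-to-end sketch of how `generate` is meant to be called (a hedged example: the checkpoint name, image paths and prompt are placeholders, and it assumes LoRA adapters named after the condition types, e.g. "subject" and "sr", have already been loaded onto the pipeline, because `generate` calls `pipeline.set_adapters(conditions[1].condition_type)` when `default_lora` is False). This fork expects exactly two conditions: `conditions[0]` drives the main condition stream and `conditions[1]` the extra one.

import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline
from flux.condition import Condition
from flux.generate import generate, seed_everything

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
# pipe.load_lora_weights(...)  # load the "subject"/"sr" adapters here (not part of this commit)

seed_everything(42)
conditions = [
    Condition("subject", raw_img=Image.open("subject.png").convert("RGB")),
    Condition("sr", raw_img=Image.open("scene.png").convert("RGB")),
]
result = generate(
    pipe,
    conditions=conditions,
    prompt="a product photo of the subject on a marble table",
    height=512,
    width=512,
    num_inference_steps=28,
)
result.images[0].save("output.png")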
flux/lora_controller.py ADDED
@@ -0,0 +1,82 @@
1
+ # As-is from OminiControl
2
+ from peft.tuners.tuners_utils import BaseTunerLayer
3
+ from typing import List, Any, Optional, Type
4
+ from .condition import condition_dict
5
+
6
+
7
+ class enable_lora:
8
+ def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
9
+ self.activated: bool = activated
10
+ if activated:
11
+ return
12
+ self.lora_modules: List[BaseTunerLayer] = [
13
+ each for each in lora_modules if isinstance(each, BaseTunerLayer)
14
+ ]
15
+ self.scales = [
16
+ {
17
+ active_adapter: lora_module.scaling[active_adapter]
18
+ for active_adapter in lora_module.active_adapters
19
+ }
20
+ for lora_module in self.lora_modules
21
+ ]
22
+
23
+ def __enter__(self) -> None:
24
+ if self.activated:
25
+ return
26
+
27
+ for lora_module in self.lora_modules:
28
+ if not isinstance(lora_module, BaseTunerLayer):
29
+ continue
30
+ for active_adapter in lora_module.active_adapters:
31
+ if (
32
+ active_adapter in condition_dict.keys()
33
+ or active_adapter == "default"
34
+ ):
35
+ lora_module.scaling[active_adapter] = 0.0
36
+
37
+ def __exit__(
38
+ self,
39
+ exc_type: Optional[Type[BaseException]],
40
+ exc_val: Optional[BaseException],
41
+ exc_tb: Optional[Any],
42
+ ) -> None:
43
+ if self.activated:
44
+ return
45
+ for i, lora_module in enumerate(self.lora_modules):
46
+ if not isinstance(lora_module, BaseTunerLayer):
47
+ continue
48
+ for active_adapter in lora_module.active_adapters:
49
+ lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
50
+
51
+
52
+ class set_lora_scale:
53
+ def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
54
+ self.lora_modules: List[BaseTunerLayer] = [
55
+ each for each in lora_modules if isinstance(each, BaseTunerLayer)
56
+ ]
57
+ self.scales = [
58
+ {
59
+ active_adapter: lora_module.scaling[active_adapter]
60
+ for active_adapter in lora_module.active_adapters
61
+ }
62
+ for lora_module in self.lora_modules
63
+ ]
64
+ self.scale = scale
65
+
66
+ def __enter__(self) -> None:
67
+ for lora_module in self.lora_modules:
68
+ if not isinstance(lora_module, BaseTunerLayer):
69
+ continue
70
+ lora_module.scale_layer(self.scale)
71
+
72
+ def __exit__(
73
+ self,
74
+ exc_type: Optional[Type[BaseException]],
75
+ exc_val: Optional[BaseException],
76
+ exc_tb: Optional[Any],
77
+ ) -> None:
78
+ for i, lora_module in enumerate(self.lora_modules):
79
+ if not isinstance(lora_module, BaseTunerLayer):
80
+ continue
81
+ for active_adapter in lora_module.active_adapters:
82
+ lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
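A self-contained sketch of the two context managers (assumptions: `peft` is installed and `inject_adapter_in_model` is used here only to build a toy LoRA layer; this is not how the commit wires things up, it just shows the scaling behaviour):

import torch
from peft import LoraConfig, inject_adapter_in_model
from flux.lora_controller import enable_lora, set_lora_scale

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)
    def forward(self, x):
        return self.proj(x)

# Attach a "default" LoRA adapter so model.proj becomes a peft BaseTunerLayer.
model = inject_adapter_in_model(
    LoraConfig(r=4, target_modules=["proj"], init_lora_weights=False), Tiny()
)
x = torch.randn(1, 8)

with enable_lora((model.proj,), False):   # adapter scaling forced to 0.0 inside the block
    y_base = model(x)
y_lora = model(x)                         # original scaling restored on exit

with set_lora_scale((model.proj,), 0.5):  # temporarily halve the adapter contribution
    y_half = model(x)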
flux/pipeline_tools.py ADDED
@@ -0,0 +1,53 @@
1
+ # As-is from OminiControl
2
+ from diffusers.pipelines import FluxPipeline
3
+ from diffusers.utils import logging
4
+ from diffusers.pipelines.flux.pipeline_flux import logger
5
+ from torch import Tensor
6
+
7
+
8
+ def encode_images(pipeline: FluxPipeline, images: Tensor):
9
+ images = pipeline.image_processor.preprocess(images)
10
+ images = images.to(pipeline.device).to(pipeline.dtype)
11
+ images = pipeline.vae.encode(images).latent_dist.sample()
12
+ images = (
13
+ images - pipeline.vae.config.shift_factor
14
+ ) * pipeline.vae.config.scaling_factor
15
+ images_tokens = pipeline._pack_latents(images, *images.shape)
16
+ images_ids = pipeline._prepare_latent_image_ids(
17
+ images.shape[0],
18
+ images.shape[2],
19
+ images.shape[3],
20
+ pipeline.device,
21
+ pipeline.dtype,
22
+ )
23
+ if images_tokens.shape[1] != images_ids.shape[0]:
24
+ images_ids = pipeline._prepare_latent_image_ids(
25
+ images.shape[0],
26
+ images.shape[2] // 2,
27
+ images.shape[3] // 2,
28
+ pipeline.device,
29
+ pipeline.dtype,
30
+ )
31
+ return images_tokens, images_ids
32
+
33
+
34
+ def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512):
35
+ # Turn off warnings (CLIP overflow)
36
+ logger.setLevel(logging.ERROR)
37
+ (
38
+ prompt_embeds,
39
+ pooled_prompt_embeds,
40
+ text_ids,
41
+ ) = pipeline.encode_prompt(
42
+ prompt=prompts,
43
+ prompt_2=None,
44
+ prompt_embeds=None,
45
+ pooled_prompt_embeds=None,
46
+ device=pipeline.device,
47
+ num_images_per_prompt=1,
48
+ max_sequence_length=max_sequence_length,
49
+ lora_scale=None,
50
+ )
51
+ # Turn on warnings
52
+ logger.setLevel(logging.WARNING)
53
+ return prompt_embeds, pooled_prompt_embeds, text_ids
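A small sketch of the two helpers above (again assuming `pipe` is an already-loaded `FluxPipeline`; the image path is a placeholder):

from PIL import Image
from flux.pipeline_tools import encode_images, prepare_text_input

image = Image.open("cond.png").convert("RGB").resize((512, 512))
tokens, ids = encode_images(pipe, image)
# For a 512x512 input: tokens ~ [1, 32 * 32, 64] packed VAE latents, ids ~ [32 * 32, 3].

prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
    pipe, ["a photo of a sofa"], max_sequence_length=512
)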
flux/transformer.py ADDED
@@ -0,0 +1,286 @@
1
+ # Recycled from OminiControl and modified to accept an extra condition.
2
+ # While Zenctrl pursued a similar idea, it diverged structurally.
3
+ # We appreciate the clarity of Omini's implementation and decided to align with it.
4
+
5
+ import torch
6
+ from diffusers.pipelines import FluxPipeline
7
+ from typing import List, Union, Optional, Dict, Any, Callable
8
+ from .block import block_forward, single_block_forward
9
+ from .lora_controller import enable_lora
10
+ from accelerate.utils import is_torch_version
11
+ from diffusers.models.transformers.transformer_flux import (
12
+ FluxTransformer2DModel,
13
+ Transformer2DModelOutput,
14
+ USE_PEFT_BACKEND,
15
+ scale_lora_layers,
16
+ unscale_lora_layers,
17
+ logger,
18
+ )
19
+ import numpy as np
20
+
21
+
22
+ def prepare_params(
23
+ hidden_states: torch.Tensor,
24
+ encoder_hidden_states: torch.Tensor = None,
25
+ pooled_projections: torch.Tensor = None,
26
+ timestep: torch.LongTensor = None,
27
+ img_ids: torch.Tensor = None,
28
+ txt_ids: torch.Tensor = None,
29
+ guidance: torch.Tensor = None,
30
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
31
+ controlnet_block_samples=None,
32
+ controlnet_single_block_samples=None,
33
+ return_dict: bool = True,
34
+ **kwargs: dict,
35
+ ):
36
+ return (
37
+ hidden_states,
38
+ encoder_hidden_states,
39
+ pooled_projections,
40
+ timestep,
41
+ img_ids,
42
+ txt_ids,
43
+ guidance,
44
+ joint_attention_kwargs,
45
+ controlnet_block_samples,
46
+ controlnet_single_block_samples,
47
+ return_dict,
48
+ )
49
+
50
+
51
+ def tranformer_forward(
52
+ transformer: FluxTransformer2DModel,
53
+ condition_latents: torch.Tensor,
54
+ extra_condition_latents: torch.Tensor,
55
+ condition_ids: torch.Tensor,
56
+ condition_type_ids: torch.Tensor,
57
+ extra_condition_ids: torch.Tensor,
58
+ extra_condition_type_ids: torch.Tensor,
59
+ model_config: Optional[Dict[str, Any]] = {},
60
+ c_t=0,
61
+ **params: dict,
62
+ ):
63
+ self = transformer
64
+ use_condition = condition_latents is not None
65
+ use_extra_condition = extra_condition_latents is not None
66
+
67
+ (
68
+ hidden_states,
69
+ encoder_hidden_states,
70
+ pooled_projections,
71
+ timestep,
72
+ img_ids,
73
+ txt_ids,
74
+ guidance,
75
+ joint_attention_kwargs,
76
+ controlnet_block_samples,
77
+ controlnet_single_block_samples,
78
+ return_dict,
79
+ ) = prepare_params(**params)
80
+
81
+ if joint_attention_kwargs is not None:
82
+ joint_attention_kwargs = joint_attention_kwargs.copy()
83
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
84
+ else:
85
+ lora_scale = 1.0
86
+
87
+ if USE_PEFT_BACKEND:
88
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
89
+ scale_lora_layers(self, lora_scale)
90
+ else:
91
+ if (
92
+ joint_attention_kwargs is not None
93
+ and joint_attention_kwargs.get("scale", None) is not None
94
+ ):
95
+ logger.warning(
96
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
97
+ )
98
+
99
+ with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
100
+ hidden_states = self.x_embedder(hidden_states)
101
+ condition_latents = self.x_embedder(condition_latents) if use_condition else None
102
+ extra_condition_latents = self.x_embedder(extra_condition_latents) if use_extra_condition else None
103
+
104
+ timestep = timestep.to(hidden_states.dtype) * 1000
105
+
106
+ if guidance is not None:
107
+ guidance = guidance.to(hidden_states.dtype) * 1000
108
+ else:
109
+ guidance = None
110
+
111
+ temb = (
112
+ self.time_text_embed(timestep, pooled_projections)
113
+ if guidance is None
114
+ else self.time_text_embed(timestep, guidance, pooled_projections)
115
+ )
116
+
117
+ cond_temb = (
118
+ self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
119
+ if guidance is None
120
+ else self.time_text_embed(
121
+ torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections
122
+ )
123
+ )
124
+ extra_cond_temb = (
125
+ self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
126
+ if guidance is None
127
+ else self.time_text_embed(
128
+ torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections
129
+ )
130
+ )
131
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
132
+
133
+ if txt_ids.ndim == 3:
134
+ logger.warning(
135
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
136
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
137
+ )
138
+ txt_ids = txt_ids[0]
139
+ if img_ids.ndim == 3:
140
+ logger.warning(
141
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
142
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
143
+ )
144
+ img_ids = img_ids[0]
145
+
146
+ ids = torch.cat((txt_ids, img_ids), dim=0)
147
+ image_rotary_emb = self.pos_embed(ids)
148
+ if use_condition:
149
+ # condition_ids[:, :1] = condition_type_ids
150
+ cond_rotary_emb = self.pos_embed(condition_ids)
151
+
152
+ if use_extra_condition:
153
+ extra_cond_rotary_emb = self.pos_embed(extra_condition_ids)
154
+
155
+
156
+ # hidden_states = torch.cat([hidden_states, condition_latents], dim=1)
157
+
158
+ #print("here!")
159
+ for index_block, block in enumerate(self.transformer_blocks):
160
+ if self.training and self.gradient_checkpointing:
161
+ ckpt_kwargs: Dict[str, Any] = (
162
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
163
+ )
164
+ encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents = (
165
+ torch.utils.checkpoint.checkpoint(
166
+ block_forward,
167
+ self=block,
168
+ model_config=model_config,
169
+ hidden_states=hidden_states,
170
+ encoder_hidden_states=encoder_hidden_states,
171
+ condition_latents=condition_latents if use_condition else None,
172
+ extra_condition_latents=extra_condition_latents if use_extra_condition else None,
173
+ temb=temb,
174
+ cond_temb=cond_temb if use_condition else None,
175
+ cond_rotary_emb=cond_rotary_emb if use_condition else None,
176
+ extra_cond_temb=extra_cond_temb if use_extra_condition else None,
177
+ extra_cond_rotary_emb=extra_cond_rotary_emb if use_extra_condition else None,
178
+ image_rotary_emb=image_rotary_emb,
179
+ **ckpt_kwargs,
180
+ )
181
+ )
182
+
183
+ else:
184
+ encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents = block_forward(
185
+ block,
186
+ model_config=model_config,
187
+ hidden_states=hidden_states,
188
+ encoder_hidden_states=encoder_hidden_states,
189
+ condition_latents=condition_latents if use_condition else None,
190
+ extra_condition_latents=extra_condition_latents if use_extra_condition else None,
191
+ temb=temb,
192
+ cond_temb=cond_temb if use_condition else None,
193
+ cond_rotary_emb=cond_rotary_emb if use_condition else None,
194
+ extra_cond_temb=extra_cond_temb if use_extra_condition else None,
195
+ extra_cond_rotary_emb=extra_cond_rotary_emb if use_extra_condition else None,
196
+ image_rotary_emb=image_rotary_emb,
197
+ )
198
+
199
+ # controlnet residual
200
+ if controlnet_block_samples is not None:
201
+ interval_control = len(self.transformer_blocks) / len(
202
+ controlnet_block_samples
203
+ )
204
+ interval_control = int(np.ceil(interval_control))
205
+ hidden_states = (
206
+ hidden_states
207
+ + controlnet_block_samples[index_block // interval_control]
208
+ )
209
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
210
+
211
+
212
+ for index_block, block in enumerate(self.single_transformer_blocks):
213
+ if self.training and self.gradient_checkpointing:
214
+ ckpt_kwargs: Dict[str, Any] = (
215
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
216
+ )
217
+ result = torch.utils.checkpoint.checkpoint(
218
+ single_block_forward,
219
+ self=block,
220
+ model_config=model_config,
221
+ hidden_states=hidden_states,
222
+ temb=temb,
223
+ image_rotary_emb=image_rotary_emb,
224
+ **(
225
+ {
226
+ "condition_latents": condition_latents,
227
+ "extra_condition_latents": extra_condition_latents,
228
+ "cond_temb": cond_temb,
229
+ "cond_rotary_emb": cond_rotary_emb,
230
+ "extra_cond_temb": extra_cond_temb,
231
+ "extra_cond_rotary_emb": extra_cond_rotary_emb,
232
+ }
233
+ if use_condition
234
+ else {}
235
+ ),
236
+ **ckpt_kwargs,
237
+ )
238
+
239
+ else:
240
+ result = single_block_forward(
241
+ block,
242
+ model_config=model_config,
243
+ hidden_states=hidden_states,
244
+ temb=temb,
245
+ image_rotary_emb=image_rotary_emb,
246
+ **(
247
+ {
248
+ "condition_latents": condition_latents,
249
+ "extra_condition_latents": extra_condition_latents,
250
+ "cond_temb": cond_temb,
251
+ "cond_rotary_emb": cond_rotary_emb,
252
+ "extra_cond_temb": extra_cond_temb,
253
+ "extra_cond_rotary_emb": extra_cond_rotary_emb,
254
+ }
255
+ if use_condition
256
+ else {}
257
+ ),
258
+ )
259
+ if use_condition:
260
+ hidden_states, condition_latents, extra_condition_latents = result
261
+ else:
262
+ hidden_states = result
263
+
264
+ # controlnet residual
265
+ if controlnet_single_block_samples is not None:
266
+ interval_control = len(self.single_transformer_blocks) / len(
267
+ controlnet_single_block_samples
268
+ )
269
+ interval_control = int(np.ceil(interval_control))
270
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
271
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
272
+ + controlnet_single_block_samples[index_block // interval_control]
273
+ )
274
+
275
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
276
+
277
+ hidden_states = self.norm_out(hidden_states, temb)
278
+ output = self.proj_out(hidden_states)
279
+
280
+ if USE_PEFT_BACKEND:
281
+ # remove `lora_scale` from each PEFT layer
282
+ unscale_lora_layers(self, lora_scale)
283
+
284
+ if not return_dict:
285
+ return (output,)
286
+ return Transformer2DModelOutput(sample=output)