Vikram Voleti
commited on
Commit
·
74371ce
1
Parent(s):
9af5ae9
SD3.5 Medium
Browse files- mmditx.py +6 -6
- sd3_infer.py +17 -5
mmditx.py
CHANGED
@@ -583,7 +583,7 @@ class DismantledBlock(nn.Module):
|
|
583 |
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
584 |
)
|
585 |
x = x + mlp_
|
586 |
-
return x
|
587 |
|
588 |
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
589 |
assert not self.pre_only
|
@@ -607,11 +607,10 @@ def block_mixing(context, x, context_block, x_block, c):
|
|
607 |
else:
|
608 |
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
609 |
|
610 |
-
|
611 |
-
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
attn = attention(q, k, v, x_block.attn.num_heads)
|
616 |
context_attn, x_attn = (
|
617 |
attn[:, : context_qkv[0].shape[1]],
|
@@ -626,6 +625,7 @@ def block_mixing(context, x, context_block, x_block, c):
|
|
626 |
if x_block.x_block_self_attn:
|
627 |
x_q2, x_k2, x_v2 = x_qkv2
|
628 |
attn2 = attention(x_q2, x_k2, x_v2, x_block.attn2.num_heads)
|
|
|
629 |
else:
|
630 |
x = x_block.post_attention(x_attn, *x_intermediates)
|
631 |
|
|
|
583 |
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
584 |
)
|
585 |
x = x + mlp_
|
586 |
+
return x
|
587 |
|
588 |
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
589 |
assert not self.pre_only
|
|
|
607 |
else:
|
608 |
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
609 |
|
610 |
+
q, k, v = tuple(
|
611 |
+
torch.cat(tuple(qkv[i] for qkv in [context_qkv, x_qkv]), dim=1)
|
612 |
+
for i in range(3)
|
613 |
+
)
|
|
|
614 |
attn = attention(q, k, v, x_block.attn.num_heads)
|
615 |
context_attn, x_attn = (
|
616 |
attn[:, : context_qkv[0].shape[1]],
|
|
|
625 |
if x_block.x_block_self_attn:
|
626 |
x_q2, x_k2, x_v2 = x_qkv2
|
627 |
attn2 = attention(x_q2, x_k2, x_v2, x_block.attn2.num_heads)
|
628 |
+
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
|
629 |
else:
|
630 |
x = x_block.post_attention(x_attn, *x_intermediates)
|
631 |
|
sd3_infer.py
CHANGED
@@ -363,6 +363,12 @@ CONFIGS = {
|
|
363 |
"steps": 50,
|
364 |
"sampler": "dpmpp_2m",
|
365 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
366 |
"sd3.5_large": {
|
367 |
"shift": 3.0,
|
368 |
"cfg": 4.5,
|
@@ -392,12 +398,18 @@ def main(
|
|
392 |
denoise=DENOISE,
|
393 |
verbose=False,
|
394 |
):
|
395 |
-
steps = steps or CONFIGS
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
|
|
|
|
|
|
400 |
)
|
|
|
|
|
|
|
401 |
|
402 |
inferencer = SD3Inferencer()
|
403 |
inferencer.load(model, vae, shift, verbose)
|
|
|
363 |
"steps": 50,
|
364 |
"sampler": "dpmpp_2m",
|
365 |
},
|
366 |
+
"sd3.5_medium": {
|
367 |
+
"shift": 3.0,
|
368 |
+
"cfg": 5.0,
|
369 |
+
"steps": 50,
|
370 |
+
"sampler": "dpmpp_2m",
|
371 |
+
},
|
372 |
"sd3.5_large": {
|
373 |
"shift": 3.0,
|
374 |
"cfg": 4.5,
|
|
|
398 |
denoise=DENOISE,
|
399 |
verbose=False,
|
400 |
):
|
401 |
+
steps = steps or CONFIGS.get(os.path.splitext(os.path.basename(model))[0], {}).get(
|
402 |
+
"steps", 50
|
403 |
+
)
|
404 |
+
cfg = cfg or CONFIGS.get(os.path.splitext(os.path.basename(model))[0], {}).get(
|
405 |
+
"cfg", 5
|
406 |
+
)
|
407 |
+
shift = shift or CONFIGS.get(os.path.splitext(os.path.basename(model))[0], {}).get(
|
408 |
+
"shift", 3
|
409 |
)
|
410 |
+
sampler = sampler or CONFIGS.get(
|
411 |
+
os.path.splitext(os.path.basename(model))[0], {}
|
412 |
+
).get("sampler", "dpmpp_2m")
|
413 |
|
414 |
inferencer = SD3Inferencer()
|
415 |
inferencer.load(model, vae, shift, verbose)
|