Vikram Voleti committed on
Commit
74371ce
·
1 Parent(s): 9af5ae9

SD3.5 Medium

Browse files
Files changed (2) hide show
  1. mmditx.py +6 -6
  2. sd3_infer.py +17 -5
mmditx.py CHANGED
@@ -583,7 +583,7 @@ class DismantledBlock(nn.Module):
583
  modulate(self.norm2(x), shift_mlp, scale_mlp)
584
  )
585
  x = x + mlp_
586
- return x, (gate_msa, gate_msa2, gate_mlp, attn_, attn2_)
587
 
588
  def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
589
  assert not self.pre_only
@@ -607,11 +607,10 @@ def block_mixing(context, x, context_block, x_block, c):
607
  else:
608
  x_qkv, x_intermediates = x_block.pre_attention(x, c)
609
 
610
- o = []
611
- for t in range(3):
612
- o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))
613
- q, k, v = tuple(o)
614
-
615
  attn = attention(q, k, v, x_block.attn.num_heads)
616
  context_attn, x_attn = (
617
  attn[:, : context_qkv[0].shape[1]],
@@ -626,6 +625,7 @@ def block_mixing(context, x, context_block, x_block, c):
626
  if x_block.x_block_self_attn:
627
  x_q2, x_k2, x_v2 = x_qkv2
628
  attn2 = attention(x_q2, x_k2, x_v2, x_block.attn2.num_heads)
 
629
  else:
630
  x = x_block.post_attention(x_attn, *x_intermediates)
631
 
 
583
  modulate(self.norm2(x), shift_mlp, scale_mlp)
584
  )
585
  x = x + mlp_
586
+ return x
587
 
588
  def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
589
  assert not self.pre_only
 
607
  else:
608
  x_qkv, x_intermediates = x_block.pre_attention(x, c)
609
 
610
+ q, k, v = tuple(
611
+ torch.cat(tuple(qkv[i] for qkv in [context_qkv, x_qkv]), dim=1)
612
+ for i in range(3)
613
+ )
 
614
  attn = attention(q, k, v, x_block.attn.num_heads)
615
  context_attn, x_attn = (
616
  attn[:, : context_qkv[0].shape[1]],
 
625
  if x_block.x_block_self_attn:
626
  x_q2, x_k2, x_v2 = x_qkv2
627
  attn2 = attention(x_q2, x_k2, x_v2, x_block.attn2.num_heads)
628
+ x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
629
  else:
630
  x = x_block.post_attention(x_attn, *x_intermediates)
631
 
sd3_infer.py CHANGED
@@ -363,6 +363,12 @@ CONFIGS = {
363
  "steps": 50,
364
  "sampler": "dpmpp_2m",
365
  },
 
 
 
 
 
 
366
  "sd3.5_large": {
367
  "shift": 3.0,
368
  "cfg": 4.5,
@@ -392,12 +398,18 @@ def main(
392
  denoise=DENOISE,
393
  verbose=False,
394
  ):
395
- steps = steps or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["steps"]
396
- cfg = cfg or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["cfg"]
397
- shift = shift or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["shift"]
398
- sampler = (
399
- sampler or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["sampler"]
 
 
 
400
  )
 
 
 
401
 
402
  inferencer = SD3Inferencer()
403
  inferencer.load(model, vae, shift, verbose)
 
363
  "steps": 50,
364
  "sampler": "dpmpp_2m",
365
  },
366
+ "sd3.5_medium": {
367
+ "shift": 3.0,
368
+ "cfg": 5.0,
369
+ "steps": 50,
370
+ "sampler": "dpmpp_2m",
371
+ },
372
  "sd3.5_large": {
373
  "shift": 3.0,
374
  "cfg": 4.5,
 
398
  denoise=DENOISE,
399
  verbose=False,
400
  ):
401
+ steps = steps or CONFIGS.get(os.path.splitext(os.path.basename(model))[0], {}).get(
402
+ "steps", 50
403
+ )
404
+ cfg = cfg or CONFIGS.get(os.path.splitext(os.path.basename(model))[0], {}).get(
405
+ "cfg", 5
406
+ )
407
+ shift = shift or CONFIGS.get(os.path.splitext(os.path.basename(model))[0], {}).get(
408
+ "shift", 3
409
  )
410
+ sampler = sampler or CONFIGS.get(
411
+ os.path.splitext(os.path.basename(model))[0], {}
412
+ ).get("sampler", "dpmpp_2m")
413
 
414
  inferencer = SD3Inferencer()
415
  inferencer.load(model, vae, shift, verbose)