frankleeeee committed on
Commit 3065798 (1 parent: f4ad4f1)

Upload STDiT

Files changed (2):
  1. config.json +2 -1
  2. modeling_stdit.py +3 -2
config.json CHANGED
@@ -11,6 +11,7 @@
  "depth": 28,
  "drop_path": 0.0,
  "enable_flash_attn": false,
+ "enable_flashattn": false,
  "enable_layernorm_kernel": false,
  "enable_sequence_parallelism": false,
  "freeze": null,
@@ -32,7 +33,7 @@
  2
  ],
  "pred_sigma": true,
- "space_scale": 0.5,
+ "space_scale": 1.0,
  "time_scale": 1.0,
  "torch_dtype": "float32",
  "transformers_version": "4.38.2"
modeling_stdit.py CHANGED
@@ -112,7 +112,7 @@ class STDiT(PreTrainedModel):
  x = x.to(self.final_layer.linear.weight.dtype)
  timestep = timestep.to(self.final_layer.linear.weight.dtype)
  y = y.to(self.final_layer.linear.weight.dtype)
-
+
  # embedding
  x = self.x_embedder(x) # [B, N, C]
  x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
@@ -148,7 +148,8 @@
  tpe = self.pos_embed_temporal
  else:
  tpe = None
- x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe)
+ x = block(x, y, t0, y_lens, tpe)
+ # x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe)

  if self.enable_sequence_parallelism:
  x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
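The modeling change calls each block directly and leaves the auto_grad_checkpoint call commented out, so the uploaded model runs its forward pass without activation checkpointing. A minimal sketch, assuming one wanted to restore checkpointing behind a flag; run_block is a hypothetical wrapper, not part of modeling_stdit.py, and it uses PyTorch's standard torch.utils.checkpoint rather than the repo's auto_grad_checkpoint helper:

import torch.utils.checkpoint

def run_block(block, x, y, t0, y_lens, tpe, use_checkpoint: bool = False):
    # Hypothetical wrapper: the commit's code path is the plain call below.
    if use_checkpoint:
        # Recompute activations during backward to save memory
        # (standard PyTorch activation checkpointing).
        return torch.utils.checkpoint.checkpoint(
            block, x, y, t0, y_lens, tpe, use_reentrant=False
        )
    return block(x, y, t0, y_lens, tpe)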