frankleeeee committed • Commit 3065798 • Parent(s): f4ad4f1

Upload STDiT

Files changed:
- config.json (+2 -1)
- modeling_stdit.py (+3 -2)
config.json
CHANGED
@@ -11,6 +11,7 @@
   "depth": 28,
   "drop_path": 0.0,
   "enable_flash_attn": false,
+  "enable_flashattn": false,
   "enable_layernorm_kernel": false,
   "enable_sequence_parallelism": false,
   "freeze": null,
@@ -32,7 +33,7 @@
     2
   ],
   "pred_sigma": true,
-  "space_scale": 0
+  "space_scale": 1.0,
   "time_scale": 1.0,
   "torch_dtype": "float32",
   "transformers_version": "4.38.2"
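Net effect of the config change: config.json now carries both spellings of the flash-attention flag (the existing `enable_flash_attn` plus the newly added `enable_flashattn`), and `space_scale` becomes `1.0`. As a minimal sketch of why duplicating the key is harmless for consumers (the file path and the `get()` fallback chain below are illustrative assumptions, not code from this repository):

```python
# Sketch: read the config and accept either spelling of the
# flash-attention flag. Path and fallback logic are assumptions
# for illustration only.
import json

with open("config.json") as f:
    cfg = json.load(f)

# Prefer the new key, fall back to the old one, default to False.
use_flash_attn = cfg.get("enable_flashattn", cfg.get("enable_flash_attn", False))
space_scale = cfg.get("space_scale", 1.0)
print(use_flash_attn, space_scale)  # -> False 1.0 with this commit's values
```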
modeling_stdit.py
CHANGED
@@ -112,7 +112,7 @@ class STDiT(PreTrainedModel):
         x = x.to(self.final_layer.linear.weight.dtype)
         timestep = timestep.to(self.final_layer.linear.weight.dtype)
         y = y.to(self.final_layer.linear.weight.dtype)
-
+
         # embedding
         x = self.x_embedder(x)  # [B, N, C]
         x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
@@ -148,7 +148,8 @@ class STDiT(PreTrainedModel):
                     tpe = self.pos_embed_temporal
             else:
                 tpe = None
-            x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe)
+            x = block(x, y, t0, y_lens, tpe)
+            # x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe)
 
         if self.enable_sequence_parallelism:
             x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
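The substantive change is in the second hunk: the checkpointed call `auto_grad_checkpoint(block, ...)` is replaced by a direct `block(...)` call, with the old call kept as a comment. Activation checkpointing only pays off during training, where it trades recomputation for activation memory; for an inference-oriented upload the direct call skips that machinery. A rough sketch of the behavior a helper like this typically provides (an assumption about its behavior, not the repository's actual implementation):

```python
# Sketch of what a helper like auto_grad_checkpoint commonly does
# (assumed behavior, not this repo's code): checkpoint the block
# while training to save activation memory, otherwise call it directly.
import torch
from torch.utils.checkpoint import checkpoint

def auto_grad_checkpoint_sketch(module, *args):
    if module.training and torch.is_grad_enabled():
        # Recompute activations in backward instead of storing them.
        return checkpoint(module, *args, use_reentrant=False)
    return module(*args)  # the direct call this commit switches to
```

At inference time `module.training` is False, so both paths reduce to the same direct call; the commit simply hard-codes that path.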