frankleeeee
commited on
Commit
•
f4ad4f1
1
Parent(s):
9637da1
Update modeling_stdit.py
Browse files- modeling_stdit.py +4 -0
modeling_stdit.py
CHANGED
@@ -109,6 +109,10 @@ class STDiT(PreTrainedModel):
|
|
109 |
Returns:
|
110 |
x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
|
111 |
"""
|
|
|
|
|
|
|
|
|
112 |
# embedding
|
113 |
x = self.x_embedder(x) # [B, N, C]
|
114 |
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
|
|
|
109 |
Returns:
|
110 |
x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
|
111 |
"""
|
112 |
+
x = x.to(self.final_layer.linear.weight.dtype)
|
113 |
+
timestep = timestep.to(self.final_layer.linear.weight.dtype)
|
114 |
+
y = y.to(self.final_layer.linear.weight.dtype)
|
115 |
+
|
116 |
# embedding
|
117 |
x = self.x_embedder(x) # [B, N, C]
|
118 |
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
|