frankleeeee commited on
Commit
f4ad4f1
1 Parent(s): 9637da1

Update modeling_stdit.py

Browse files
Files changed (1) hide show
  1. modeling_stdit.py +4 -0
modeling_stdit.py CHANGED
@@ -109,6 +109,10 @@ class STDiT(PreTrainedModel):
109
  Returns:
110
  x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
111
  """
 
 
 
 
112
  # embedding
113
  x = self.x_embedder(x) # [B, N, C]
114
  x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
 
109
  Returns:
110
  x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
111
  """
112
+ x = x.to(self.final_layer.linear.weight.dtype)
113
+ timestep = timestep.to(self.final_layer.linear.weight.dtype)
114
+ y = y.to(self.final_layer.linear.weight.dtype)
115
+
116
  # embedding
117
  x = self.x_embedder(x) # [B, N, C]
118
  x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)