Update DiT.py
DiT.py  CHANGED
@@ -55,13 +55,18 @@ class TimestepEmbedder:
         return t_emb


-class LabelEmbedder:
+class LabelEmbedder(tf.keras.layers.Layer):
     """
     Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
     """
     def __init__(self, num_classes, hidden_size, dropout_prob):
         use_cfg_embedding = dropout_prob > 0
-        self.embedding_table =
+        self.embedding_table = self.add_weight(
+            name='embedding_table',
+            shape=(num_classes + use_cfg_embedding, hidden_size),
+            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
+            trainable=True
+        )
         self.num_classes = num_classes
         self.dropout_prob = dropout_prob

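Note on the hunk above: the committed __init__ calls self.add_weight() without first calling super().__init__(), which a tf.keras.layers.Layer subclass requires before any weight can be created. The class docstring also promises label dropout for classifier-free guidance, but that path is not part of this change. A minimal sketch of the complete layer, with the missing super() call added; the token_drop/call pair is an assumption modeled on the reference PyTorch DiT implementation, not code from this commit:

import tensorflow as tf

class LabelEmbedder(tf.keras.layers.Layer):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()  # required before self.add_weight()
        use_cfg_embedding = dropout_prob > 0
        # One extra embedding row acts as the "null" label substituted
        # when a real label is dropped for classifier-free guidance.
        self.embedding_table = self.add_weight(
            name='embedding_table',
            shape=(num_classes + use_cfg_embedding, hidden_size),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
            trainable=True
        )
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels):
        # Assumed helper: swap each label for the null index
        # (self.num_classes) with probability dropout_prob.
        drop = tf.random.uniform(tf.shape(labels)) < self.dropout_prob
        null_labels = tf.fill(tf.shape(labels),
                              tf.constant(self.num_classes, dtype=labels.dtype))
        return tf.where(drop, null_labels, labels)

    def call(self, labels, training=False):
        if training and self.dropout_prob > 0:
            labels = self.token_drop(labels)
        return tf.gather(self.embedding_table, labels)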
@@ -156,7 +161,12 @@ class DiT(Model):
         self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
         num_patches = self.x_embedder.num_patches
         # Will use fixed sin-cos embedding:
-        self.pos_embed =
+        self.pos_embed = self.add_weight(
+            name='pos_embed',
+            shape=(1, num_patches, hidden_size),
+            initializer=tf.keras.initializers.Zeros(),
+            trainable=False  # To freeze this variable
+        )

         self.blocks = [
             DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
@@ -167,8 +177,7 @@ class DiT(Model):
     def initialize_weights(self):
         # Initialize (and freeze) pos_embed by sin-cos embedding:
         pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
-        self.pos_embed
-        tf.Variable(self.pos_embed)
+        self.pos_embed.assign(tf.convert_to_tensor(pos_embed, dtype=tf.float32)[tf.newaxis, :])

     def unpatchify(self, x):
         """