# Source: OpenSora-STDiT-v2-stage3 / configuration_stdit2.py
# Uploaded by frankleeeee ("Upload STDiT2"), commit 3074d7c (verified).
import torch
from transformers import PretrainedConfig
class STDiT2Config(PretrainedConfig):
    """Hugging Face configuration object for the STDiT2 model.

    The class registers itself with ``model_type = "stdit2"``. Every named
    constructor argument is stored unchanged as an instance attribute with
    the same name. Any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    This file does not show how the values are interpreted; that is up to
    the model code that reads this config. The short descriptions below are
    inferred from parameter names and should be confirmed against that code.

    Args:
        input_size: 3-element size tuple. The default ``(None, None, None)``
            presumably leaves every dimension unspecified (TODO confirm).
        input_sq_size: Integer size setting, default ``32``. Presumably a
            reference square spatial size (TODO confirm).
        in_channels: Number of input channels.
        patch_size: 3-element patch size tuple, default ``(1, 2, 2)``.
        hidden_size: Hidden width of the model.
        depth: Number of stacked blocks.
        num_heads: Number of attention heads.
        mlp_ratio: Presumably the MLP width relative to ``hidden_size``.
        class_dropout_prob: Dropout probability applied to conditioning
            (assumed; TODO confirm).
        pred_sigma: Boolean flag. Presumably makes the model also predict a
            variance term (TODO confirm).
        drop_path: Drop-path rate.
        no_temporal_pos_emb: Boolean flag. Presumably disables a temporal
            positional embedding.
        caption_channels: Width of the caption/text conditioning features.
        model_max_length: Presumably the maximum caption token length.
        freeze: Freezing option, ``None`` by default. Its accepted values
            are not visible here.
        qk_norm: Boolean flag. Presumably enables query/key normalization.
        enable_flash_attn: Boolean flag for a flash-attention code path.
        enable_layernorm_kernel: Boolean flag for a fused layer-norm kernel.
        enable_sequence_parallelism: Boolean flag for sequence parallelism.
        **kwargs: Passed through untouched to ``PretrainedConfig.__init__``.
    """

    model_type = "stdit2"

    def __init__(
        self,
        input_size=(None, None, None),
        input_sq_size=32,
        in_channels=4,
        patch_size=(1, 2, 2),
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        pred_sigma=True,
        drop_path=0.0,
        no_temporal_pos_emb=False,
        caption_channels=4096,
        model_max_length=120,
        freeze=None,
        qk_norm=False,
        enable_flash_attn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
        **kwargs,
    ):
        # Model-specific fields, in the order they are written onto the
        # instance. The dict keeps insertion order, so attributes are set
        # in exactly this sequence.
        model_fields = {
            "input_size": input_size,
            "input_sq_size": input_sq_size,
            "in_channels": in_channels,
            "patch_size": patch_size,
            "hidden_size": hidden_size,
            "depth": depth,
            "num_heads": num_heads,
            "mlp_ratio": mlp_ratio,
            "class_dropout_prob": class_dropout_prob,
            "pred_sigma": pred_sigma,
            "drop_path": drop_path,
            "no_temporal_pos_emb": no_temporal_pos_emb,
            "caption_channels": caption_channels,
            "model_max_length": model_max_length,
            "freeze": freeze,
            "qk_norm": qk_norm,
            "enable_flash_attn": enable_flash_attn,
            "enable_layernorm_kernel": enable_layernorm_kernel,
            "enable_sequence_parallelism": enable_sequence_parallelism,
        }
        for field_name, field_value in model_fields.items():
            # setattr routes through the class's __setattr__, just like a
            # plain `self.<name> = value` assignment would.
            setattr(self, field_name, field_value)

        # The model fields are already set; hand any leftover keyword
        # arguments to the base configuration class.
        super().__init__(**kwargs)