BAAI /

Any-to-Any · Diffusers · Safetensors · OmniGen2Pipeline

sienna223 committed · Commit e102b09 · verified · 1 parent: 0d71403

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -39,3 +39,4 @@ assets/examples_edit.png filter=lfs diff=lfs merge=lfs -text
  assets/examples_subject.png filter=lfs diff=lfs merge=lfs -text
  assets/teaser.jpg filter=lfs diff=lfs merge=lfs -text
  assets/teaser.png filter=lfs diff=lfs merge=lfs -text
+ processor/tokenizer.json filter=lfs diff=lfs merge=lfs -text
mllm/config.json ADDED
@@ -0,0 +1,65 @@
+ {
+   "architectures": [
+     "Qwen2_5_VLForConditionalGeneration"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 151643,
+   "eos_token_id": 151645,
+   "hidden_act": "silu",
+   "hidden_size": 2048,
+   "image_token_id": 151655,
+   "initializer_range": 0.02,
+   "intermediate_size": 11008,
+   "max_position_embeddings": 128000,
+   "max_window_layers": 70,
+   "model_type": "qwen2_5_vl",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 36,
+   "num_key_value_heads": 2,
+   "rms_norm_eps": 1e-06,
+   "rope_scaling": {
+     "mrope_section": [
+       16,
+       24,
+       24
+     ],
+     "rope_type": "default",
+     "type": "default"
+   },
+   "rope_theta": 1000000.0,
+   "sliding_window": 32768,
+   "tie_word_embeddings": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.51.3",
+   "use_cache": true,
+   "use_sliding_window": false,
+   "video_token_id": 151656,
+   "vision_config": {
+     "depth": 32,
+     "fullatt_block_indexes": [
+       7,
+       15,
+       23,
+       31
+     ],
+     "hidden_act": "silu",
+     "hidden_size": 1280,
+     "in_channels": 3,
+     "in_chans": 3,
+     "intermediate_size": 3420,
+     "model_type": "qwen2_5_vl",
+     "num_heads": 16,
+     "out_hidden_size": 2048,
+     "patch_size": 14,
+     "spatial_merge_size": 2,
+     "spatial_patch_size": 14,
+     "temporal_patch_size": 2,
+     "tokens_per_second": 2,
+     "torch_dtype": "float32",
+     "window_size": 112
+   },
+   "vision_end_token_id": 151653,
+   "vision_start_token_id": 151652,
+   "vision_token_id": 151654,
+   "vocab_size": 151936
+ }
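
The config above describes the Qwen2.5-VL backbone shipped under mllm/: a 36-layer, 2048-wide text decoder with grouped-query attention (16 query heads, 2 key/value heads) and a 32-block vision encoder on 14x14 patches. Below is a minimal sketch, not part of this commit, of inspecting it with transformers; the local ./mllm path is an assumption, and a transformers release with Qwen2.5-VL support (the file pins "transformers_version": "4.51.3") is required:

```python
# Minimal sketch: read mllm/config.json from a local download of this repo.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("./mllm")  # parses the config.json shown above

print(cfg.model_type)                # "qwen2_5_vl"
print(cfg.hidden_size)               # 2048
print(cfg.num_hidden_layers)         # 36
print(cfg.num_key_value_heads)       # 2 (grouped-query attention)
print(cfg.vision_config.depth)       # 32 vision blocks
print(cfg.vision_config.patch_size)  # 14

# The full multimodal model can then be instantiated from the same folder:
# from transformers import Qwen2_5_VLForConditionalGeneration
# model = Qwen2_5_VLForConditionalGeneration.from_pretrained("./mllm", torch_dtype="bfloat16")
```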
mllm/generation_config.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "_attn_implementation": "flash_attention_2",
+   "bos_token_id": 151643,
+   "do_sample": true,
+   "eos_token_id": [
+     151645,
+     151643
+   ],
+   "pad_token_id": 151643,
+   "repetition_penalty": 1.05,
+   "temperature": 0.1,
+   "top_k": 1,
+   "top_p": 0.001,
+   "transformers_version": "4.51.3"
+ }
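
These defaults pin decoding to near-greedy behaviour: top_k of 1 (and top_p of 0.001) keeps only the most probable token even though do_sample is true, with temperature 0.1 and a mild repetition_penalty of 1.05. A minimal sketch, not part of this commit, of loading and reusing them; the ./mllm path and the commented generate() call are placeholders:

```python
# Minimal sketch: load the decoding defaults above as a GenerationConfig.
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("./mllm")  # parses generation_config.json

print(gen_cfg.do_sample, gen_cfg.top_k, gen_cfg.top_p, gen_cfg.temperature)
# True 1 0.001 0.1  -> effectively greedy decoding

# Once a model and tokenized inputs exist (see the config sketch above):
# output_ids = model.generate(**inputs, generation_config=gen_cfg)
```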
mllm/model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bfe5503539cf9335017bca96cd4409cec234de71af6bbe9a6035c0952e9319c2
+ size 4972304384
mllm/model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:09b6f098d029dd4fadcc032d2d9da1914204d098c04a5a38f19bba23beb3039c
+ size 4932949248
mllm/model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fa99a4bf5f586e0f6aa2e941a66283afc98df2a8c726496f717ff7939ec70651
+ size 4932949336
mllm/model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e9bdab213ea510b2de9805ae054076d804fe578eb052460ed7988c8cc3aee114
+ size 180380208
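
Each of the four .safetensors entries above is a Git LFS pointer (spec version, sha256 oid, byte size) rather than the weights themselves; the model.safetensors.index.json added next records a total_size of 15018491904 bytes and maps every parameter name to the shard that stores it. A minimal sketch, not part of this commit, of checking a fetched shard against its pointer and resolving a tensor to its shard; the paths are assumptions, and in a normal clone `git lfs` performs this check itself:

```python
# Minimal sketch: verify a shard against its LFS pointer and query the weight map.
import hashlib
import json

def verify_lfs_pointer(pointer_path: str, blob_path: str) -> bool:
    """Compare a downloaded blob with the oid/size recorded in an LFS pointer file."""
    fields = dict(line.split(" ", 1) for line in open(pointer_path).read().splitlines() if line)
    expected_oid = fields["oid"].split(":", 1)[1]   # strip the "sha256:" prefix
    expected_size = int(fields["size"])
    digest = hashlib.sha256()
    size = 0
    with open(blob_path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
            digest.update(chunk)
            size += len(chunk)
    return digest.hexdigest() == expected_oid and size == expected_size

# Which shard holds a given parameter, according to the index added below:
index = json.load(open("mllm/model.safetensors.index.json"))
print(index["metadata"]["total_size"])                    # 15018491904
print(index["weight_map"]["model.embed_tokens.weight"])   # model-00001-of-00004.safetensors
```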
mllm/model.safetensors.index.json ADDED
@@ -0,0 +1,831 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 15018491904
4
+ },
5
+ "weight_map": {
6
+ "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
7
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
8
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
9
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
10
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
11
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
12
+ "model.layers.0.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
13
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
14
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
15
+ "model.layers.0.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
16
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
17
+ "model.layers.0.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
18
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
19
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
20
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
21
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
22
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
23
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
24
+ "model.layers.1.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
25
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
26
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
27
+ "model.layers.1.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
28
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
29
+ "model.layers.1.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
30
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
31
+ "model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
32
+ "model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
33
+ "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
34
+ "model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
35
+ "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
36
+ "model.layers.10.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
37
+ "model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
38
+ "model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
39
+ "model.layers.10.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
40
+ "model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
41
+ "model.layers.10.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
42
+ "model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
43
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
44
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
45
+ "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
46
+ "model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
47
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
48
+ "model.layers.11.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
49
+ "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
50
+ "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
51
+ "model.layers.11.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
52
+ "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
53
+ "model.layers.11.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
54
+ "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
55
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
56
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
57
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
58
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
59
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
60
+ "model.layers.12.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
61
+ "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
62
+ "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
63
+ "model.layers.12.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
64
+ "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
65
+ "model.layers.12.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
66
+ "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
67
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
68
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
69
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
70
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
71
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
72
+ "model.layers.13.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
73
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
74
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
75
+ "model.layers.13.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
76
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
77
+ "model.layers.13.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
78
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
79
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
80
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
81
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
82
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
83
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
84
+ "model.layers.14.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
85
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
86
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
87
+ "model.layers.14.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
88
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
89
+ "model.layers.14.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
90
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
91
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
92
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
93
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
94
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
95
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
96
+ "model.layers.15.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
97
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
98
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
99
+ "model.layers.15.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
100
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
101
+ "model.layers.15.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
102
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
103
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
104
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
105
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
106
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
107
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
108
+ "model.layers.16.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
109
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
110
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
111
+ "model.layers.16.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
112
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
113
+ "model.layers.16.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
114
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
115
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
116
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
117
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
118
+ "model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
119
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
120
+ "model.layers.17.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
121
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
122
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
123
+ "model.layers.17.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
124
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
125
+ "model.layers.17.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
126
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
127
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
128
+ "model.layers.18.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
129
+ "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
130
+ "model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
131
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
132
+ "model.layers.18.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
133
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
134
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
135
+ "model.layers.18.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
136
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
137
+ "model.layers.18.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
138
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
139
+ "model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
140
+ "model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
141
+ "model.layers.19.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
142
+ "model.layers.19.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
143
+ "model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
144
+ "model.layers.19.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
145
+ "model.layers.19.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
146
+ "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
147
+ "model.layers.19.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
148
+ "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
149
+ "model.layers.19.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
150
+ "model.layers.19.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
151
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
152
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
153
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
154
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
155
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
156
+ "model.layers.2.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
157
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
158
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
159
+ "model.layers.2.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
160
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
161
+ "model.layers.2.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
162
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
163
+ "model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
164
+ "model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
165
+ "model.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
166
+ "model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
167
+ "model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
168
+ "model.layers.20.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
169
+ "model.layers.20.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
170
+ "model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
171
+ "model.layers.20.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
172
+ "model.layers.20.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
173
+ "model.layers.20.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
174
+ "model.layers.20.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
175
+ "model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
176
+ "model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
177
+ "model.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
178
+ "model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
179
+ "model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
180
+ "model.layers.21.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
181
+ "model.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
182
+ "model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
183
+ "model.layers.21.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
184
+ "model.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
185
+ "model.layers.21.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
186
+ "model.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
187
+ "model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
188
+ "model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
189
+ "model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
190
+ "model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
191
+ "model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
192
+ "model.layers.22.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
193
+ "model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
194
+ "model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
195
+ "model.layers.22.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
196
+ "model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
197
+ "model.layers.22.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
198
+ "model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
199
+ "model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
200
+ "model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
201
+ "model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
202
+ "model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
203
+ "model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
204
+ "model.layers.23.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
205
+ "model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
206
+ "model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
207
+ "model.layers.23.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
208
+ "model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
209
+ "model.layers.23.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
210
+ "model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
211
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
212
+ "model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
213
+ "model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
214
+ "model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
215
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
216
+ "model.layers.24.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
217
+ "model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
218
+ "model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
219
+ "model.layers.24.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
220
+ "model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
221
+ "model.layers.24.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
222
+ "model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
223
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
224
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
225
+ "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
226
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
227
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
228
+ "model.layers.25.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
229
+ "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
230
+ "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
231
+ "model.layers.25.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
232
+ "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
233
+ "model.layers.25.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
234
+ "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
235
+ "model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
236
+ "model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
237
+ "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
238
+ "model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
239
+ "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
240
+ "model.layers.26.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
241
+ "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
242
+ "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
243
+ "model.layers.26.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
244
+ "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
245
+ "model.layers.26.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
246
+ "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
247
+ "model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
248
+ "model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
249
+ "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
250
+ "model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
251
+ "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
252
+ "model.layers.27.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
253
+ "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
254
+ "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
255
+ "model.layers.27.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
256
+ "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
257
+ "model.layers.27.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
258
+ "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
259
+ "model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
260
+ "model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
261
+ "model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
262
+ "model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
263
+ "model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
264
+ "model.layers.28.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
265
+ "model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
266
+ "model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
267
+ "model.layers.28.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
268
+ "model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
269
+ "model.layers.28.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
270
+ "model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
271
+ "model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
272
+ "model.layers.29.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
273
+ "model.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
274
+ "model.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
275
+ "model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
276
+ "model.layers.29.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
277
+ "model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
278
+ "model.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
279
+ "model.layers.29.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
280
+ "model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
281
+ "model.layers.29.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
282
+ "model.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
283
+ "model.layers.3.input_layernorm.weight": "model-00002-of-00004.safetensors",
284
+ "model.layers.3.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
285
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
286
+ "model.layers.3.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
287
+ "model.layers.3.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
288
+ "model.layers.3.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
289
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
290
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
291
+ "model.layers.3.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
292
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
293
+ "model.layers.3.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
294
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
295
+ "model.layers.30.input_layernorm.weight": "model-00003-of-00004.safetensors",
296
+ "model.layers.30.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
297
+ "model.layers.30.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
298
+ "model.layers.30.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
299
+ "model.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
300
+ "model.layers.30.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
301
+ "model.layers.30.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
302
+ "model.layers.30.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
303
+ "model.layers.30.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
304
+ "model.layers.30.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
305
+ "model.layers.30.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
306
+ "model.layers.30.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
307
+ "model.layers.31.input_layernorm.weight": "model-00003-of-00004.safetensors",
308
+ "model.layers.31.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
309
+ "model.layers.31.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
310
+ "model.layers.31.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
311
+ "model.layers.31.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
312
+ "model.layers.31.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
313
+ "model.layers.31.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
314
+ "model.layers.31.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
315
+ "model.layers.31.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
316
+ "model.layers.31.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
317
+ "model.layers.31.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
318
+ "model.layers.31.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
319
+ "model.layers.32.input_layernorm.weight": "model-00003-of-00004.safetensors",
320
+ "model.layers.32.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
321
+ "model.layers.32.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
322
+ "model.layers.32.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
323
+ "model.layers.32.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
324
+ "model.layers.32.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
325
+ "model.layers.32.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
326
+ "model.layers.32.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
327
+ "model.layers.32.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
328
+ "model.layers.32.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
329
+ "model.layers.32.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
330
+ "model.layers.32.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
331
+ "model.layers.33.input_layernorm.weight": "model-00003-of-00004.safetensors",
332
+ "model.layers.33.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
333
+ "model.layers.33.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
334
+ "model.layers.33.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
335
+ "model.layers.33.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
336
+ "model.layers.33.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
337
+ "model.layers.33.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
338
+ "model.layers.33.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
339
+ "model.layers.33.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
340
+ "model.layers.33.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
341
+ "model.layers.33.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
342
+ "model.layers.33.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
343
+ "model.layers.34.input_layernorm.weight": "model-00003-of-00004.safetensors",
344
+ "model.layers.34.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
345
+ "model.layers.34.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
346
+ "model.layers.34.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
347
+ "model.layers.34.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
348
+ "model.layers.34.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
349
+ "model.layers.34.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
350
+ "model.layers.34.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
351
+ "model.layers.34.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
352
+ "model.layers.34.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
353
+ "model.layers.34.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
354
+ "model.layers.34.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
355
+ "model.layers.35.input_layernorm.weight": "model-00004-of-00004.safetensors",
356
+ "model.layers.35.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
357
+ "model.layers.35.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
358
+ "model.layers.35.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
359
+ "model.layers.35.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
360
+ "model.layers.35.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
361
+ "model.layers.35.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
362
+ "model.layers.35.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
363
+ "model.layers.35.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
364
+ "model.layers.35.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
365
+ "model.layers.35.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
366
+ "model.layers.35.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
367
+ "model.layers.4.input_layernorm.weight": "model-00002-of-00004.safetensors",
368
+ "model.layers.4.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
369
+ "model.layers.4.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
370
+ "model.layers.4.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
371
+ "model.layers.4.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
372
+ "model.layers.4.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
373
+ "model.layers.4.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
374
+ "model.layers.4.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
375
+ "model.layers.4.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
376
+ "model.layers.4.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
377
+ "model.layers.4.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
378
+ "model.layers.4.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
379
+ "model.layers.5.input_layernorm.weight": "model-00002-of-00004.safetensors",
380
+ "model.layers.5.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
381
+ "model.layers.5.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
382
+ "model.layers.5.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
383
+ "model.layers.5.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
384
+ "model.layers.5.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
385
+ "model.layers.5.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
386
+ "model.layers.5.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
387
+ "model.layers.5.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
388
+ "model.layers.5.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
389
+ "model.layers.5.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
390
+ "model.layers.5.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
391
+ "model.layers.6.input_layernorm.weight": "model-00002-of-00004.safetensors",
392
+ "model.layers.6.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
393
+ "model.layers.6.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
394
+ "model.layers.6.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
395
+ "model.layers.6.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
396
+ "model.layers.6.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
397
+ "model.layers.6.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
398
+ "model.layers.6.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
399
+ "model.layers.6.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
400
+ "model.layers.6.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
401
+ "model.layers.6.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
402
+ "model.layers.6.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
403
+ "model.layers.7.input_layernorm.weight": "model-00002-of-00004.safetensors",
404
+ "model.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
405
+ "model.layers.7.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
406
+ "model.layers.7.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
407
+ "model.layers.7.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
408
+ "model.layers.7.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
409
+ "model.layers.7.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
410
+ "model.layers.7.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
411
+ "model.layers.7.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
412
+ "model.layers.7.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
413
+ "model.layers.7.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
414
+ "model.layers.7.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
415
+ "model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
416
+ "model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
417
+ "model.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
418
+ "model.layers.8.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
419
+ "model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
420
+ "model.layers.8.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
421
+ "model.layers.8.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
422
+ "model.layers.8.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
423
+ "model.layers.8.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
424
+ "model.layers.8.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
425
+ "model.layers.8.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
426
+ "model.layers.8.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
427
+ "model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
428
+ "model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
429
+ "model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
430
+ "model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
431
+ "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
432
+ "model.layers.9.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
433
+ "model.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
434
+ "model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
435
+ "model.layers.9.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
436
+ "model.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
437
+ "model.layers.9.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
438
+ "model.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
439
+ "model.norm.weight": "model-00004-of-00004.safetensors",
440
+ "visual.blocks.0.attn.proj.bias": "model-00001-of-00004.safetensors",
441
+ "visual.blocks.0.attn.proj.weight": "model-00001-of-00004.safetensors",
442
+ "visual.blocks.0.attn.qkv.bias": "model-00001-of-00004.safetensors",
443
+ "visual.blocks.0.attn.qkv.weight": "model-00001-of-00004.safetensors",
444
+ "visual.blocks.0.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
445
+ "visual.blocks.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
446
+ "visual.blocks.0.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
447
+ "visual.blocks.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
448
+ "visual.blocks.0.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
449
+ "visual.blocks.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
450
+ "visual.blocks.0.norm1.weight": "model-00001-of-00004.safetensors",
451
+ "visual.blocks.0.norm2.weight": "model-00001-of-00004.safetensors",
452
+ "visual.blocks.1.attn.proj.bias": "model-00001-of-00004.safetensors",
453
+ "visual.blocks.1.attn.proj.weight": "model-00001-of-00004.safetensors",
454
+ "visual.blocks.1.attn.qkv.bias": "model-00001-of-00004.safetensors",
455
+ "visual.blocks.1.attn.qkv.weight": "model-00001-of-00004.safetensors",
456
+ "visual.blocks.1.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
457
+ "visual.blocks.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
458
+ "visual.blocks.1.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
459
+ "visual.blocks.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
460
+ "visual.blocks.1.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
461
+ "visual.blocks.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
462
+ "visual.blocks.1.norm1.weight": "model-00001-of-00004.safetensors",
463
+ "visual.blocks.1.norm2.weight": "model-00001-of-00004.safetensors",
464
+ "visual.blocks.10.attn.proj.bias": "model-00001-of-00004.safetensors",
465
+ "visual.blocks.10.attn.proj.weight": "model-00001-of-00004.safetensors",
466
+ "visual.blocks.10.attn.qkv.bias": "model-00001-of-00004.safetensors",
467
+ "visual.blocks.10.attn.qkv.weight": "model-00001-of-00004.safetensors",
468
+ "visual.blocks.10.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
469
+ "visual.blocks.10.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
470
+ "visual.blocks.10.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
471
+ "visual.blocks.10.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
472
+ "visual.blocks.10.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
473
+ "visual.blocks.10.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
474
+ "visual.blocks.10.norm1.weight": "model-00001-of-00004.safetensors",
475
+ "visual.blocks.10.norm2.weight": "model-00001-of-00004.safetensors",
476
+ "visual.blocks.11.attn.proj.bias": "model-00001-of-00004.safetensors",
477
+ "visual.blocks.11.attn.proj.weight": "model-00001-of-00004.safetensors",
478
+ "visual.blocks.11.attn.qkv.bias": "model-00001-of-00004.safetensors",
479
+ "visual.blocks.11.attn.qkv.weight": "model-00001-of-00004.safetensors",
480
+ "visual.blocks.11.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
481
+ "visual.blocks.11.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
482
+ "visual.blocks.11.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
483
+ "visual.blocks.11.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
484
+ "visual.blocks.11.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
485
+ "visual.blocks.11.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
486
+ "visual.blocks.11.norm1.weight": "model-00001-of-00004.safetensors",
487
+ "visual.blocks.11.norm2.weight": "model-00001-of-00004.safetensors",
488
+ "visual.blocks.12.attn.proj.bias": "model-00001-of-00004.safetensors",
489
+ "visual.blocks.12.attn.proj.weight": "model-00001-of-00004.safetensors",
490
+ "visual.blocks.12.attn.qkv.bias": "model-00001-of-00004.safetensors",
491
+ "visual.blocks.12.attn.qkv.weight": "model-00001-of-00004.safetensors",
492
+ "visual.blocks.12.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
493
+ "visual.blocks.12.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
494
+ "visual.blocks.12.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
495
+ "visual.blocks.12.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
496
+ "visual.blocks.12.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
497
+ "visual.blocks.12.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
498
+ "visual.blocks.12.norm1.weight": "model-00001-of-00004.safetensors",
499
+ "visual.blocks.12.norm2.weight": "model-00001-of-00004.safetensors",
500
+ "visual.blocks.13.attn.proj.bias": "model-00001-of-00004.safetensors",
501
+ "visual.blocks.13.attn.proj.weight": "model-00001-of-00004.safetensors",
502
+ "visual.blocks.13.attn.qkv.bias": "model-00001-of-00004.safetensors",
503
+ "visual.blocks.13.attn.qkv.weight": "model-00001-of-00004.safetensors",
504
+ "visual.blocks.13.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
505
+ "visual.blocks.13.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
506
+ "visual.blocks.13.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
507
+ "visual.blocks.13.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
508
+ "visual.blocks.13.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
509
+ "visual.blocks.13.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
510
+ "visual.blocks.13.norm1.weight": "model-00001-of-00004.safetensors",
511
+ "visual.blocks.13.norm2.weight": "model-00001-of-00004.safetensors",
512
+ "visual.blocks.14.attn.proj.bias": "model-00001-of-00004.safetensors",
513
+ "visual.blocks.14.attn.proj.weight": "model-00001-of-00004.safetensors",
514
+ "visual.blocks.14.attn.qkv.bias": "model-00001-of-00004.safetensors",
515
+ "visual.blocks.14.attn.qkv.weight": "model-00001-of-00004.safetensors",
516
+ "visual.blocks.14.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
517
+ "visual.blocks.14.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
518
+ "visual.blocks.14.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
519
+ "visual.blocks.14.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
520
+ "visual.blocks.14.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
521
+ "visual.blocks.14.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
522
+ "visual.blocks.14.norm1.weight": "model-00001-of-00004.safetensors",
523
+ "visual.blocks.14.norm2.weight": "model-00001-of-00004.safetensors",
524
+ "visual.blocks.15.attn.proj.bias": "model-00001-of-00004.safetensors",
525
+ "visual.blocks.15.attn.proj.weight": "model-00001-of-00004.safetensors",
526
+ "visual.blocks.15.attn.qkv.bias": "model-00001-of-00004.safetensors",
527
+ "visual.blocks.15.attn.qkv.weight": "model-00001-of-00004.safetensors",
528
+ "visual.blocks.15.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
529
+ "visual.blocks.15.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
530
+ "visual.blocks.15.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
531
+ "visual.blocks.15.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
532
+ "visual.blocks.15.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
533
+ "visual.blocks.15.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
534
+ "visual.blocks.15.norm1.weight": "model-00001-of-00004.safetensors",
535
+ "visual.blocks.15.norm2.weight": "model-00001-of-00004.safetensors",
536
+ "visual.blocks.16.attn.proj.bias": "model-00001-of-00004.safetensors",
537
+ "visual.blocks.16.attn.proj.weight": "model-00001-of-00004.safetensors",
538
+ "visual.blocks.16.attn.qkv.bias": "model-00001-of-00004.safetensors",
539
+ "visual.blocks.16.attn.qkv.weight": "model-00001-of-00004.safetensors",
540
+ "visual.blocks.16.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
541
+ "visual.blocks.16.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
542
+ "visual.blocks.16.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
543
+ "visual.blocks.16.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
544
+ "visual.blocks.16.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
545
+ "visual.blocks.16.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
546
+ "visual.blocks.16.norm1.weight": "model-00001-of-00004.safetensors",
547
+ "visual.blocks.16.norm2.weight": "model-00001-of-00004.safetensors",
548
+ "visual.blocks.17.attn.proj.bias": "model-00001-of-00004.safetensors",
549
+ "visual.blocks.17.attn.proj.weight": "model-00001-of-00004.safetensors",
550
+ "visual.blocks.17.attn.qkv.bias": "model-00001-of-00004.safetensors",
551
+ "visual.blocks.17.attn.qkv.weight": "model-00001-of-00004.safetensors",
552
+ "visual.blocks.17.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
553
+ "visual.blocks.17.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
554
+ "visual.blocks.17.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
555
+ "visual.blocks.17.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
556
+ "visual.blocks.17.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
557
+ "visual.blocks.17.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
558
+ "visual.blocks.17.norm1.weight": "model-00001-of-00004.safetensors",
559
+ "visual.blocks.17.norm2.weight": "model-00001-of-00004.safetensors",
560
+ "visual.blocks.18.attn.proj.bias": "model-00001-of-00004.safetensors",
561
+ "visual.blocks.18.attn.proj.weight": "model-00001-of-00004.safetensors",
562
+ "visual.blocks.18.attn.qkv.bias": "model-00001-of-00004.safetensors",
563
+ "visual.blocks.18.attn.qkv.weight": "model-00001-of-00004.safetensors",
564
+ "visual.blocks.18.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
565
+ "visual.blocks.18.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
566
+ "visual.blocks.18.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
567
+ "visual.blocks.18.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
568
+ "visual.blocks.18.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
569
+ "visual.blocks.18.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
570
+ "visual.blocks.18.norm1.weight": "model-00001-of-00004.safetensors",
571
+ "visual.blocks.18.norm2.weight": "model-00001-of-00004.safetensors",
572
+ "visual.blocks.19.attn.proj.bias": "model-00001-of-00004.safetensors",
573
+ "visual.blocks.19.attn.proj.weight": "model-00001-of-00004.safetensors",
574
+ "visual.blocks.19.attn.qkv.bias": "model-00001-of-00004.safetensors",
575
+ "visual.blocks.19.attn.qkv.weight": "model-00001-of-00004.safetensors",
576
+ "visual.blocks.19.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
577
+ "visual.blocks.19.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
578
+ "visual.blocks.19.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
579
+ "visual.blocks.19.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
580
+ "visual.blocks.19.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
581
+ "visual.blocks.19.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
582
+ "visual.blocks.19.norm1.weight": "model-00001-of-00004.safetensors",
583
+ "visual.blocks.19.norm2.weight": "model-00001-of-00004.safetensors",
584
+ "visual.blocks.2.attn.proj.bias": "model-00001-of-00004.safetensors",
585
+ "visual.blocks.2.attn.proj.weight": "model-00001-of-00004.safetensors",
586
+ "visual.blocks.2.attn.qkv.bias": "model-00001-of-00004.safetensors",
587
+ "visual.blocks.2.attn.qkv.weight": "model-00001-of-00004.safetensors",
588
+ "visual.blocks.2.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
589
+ "visual.blocks.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
590
+ "visual.blocks.2.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
591
+ "visual.blocks.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
592
+ "visual.blocks.2.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
593
+ "visual.blocks.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
594
+ "visual.blocks.2.norm1.weight": "model-00001-of-00004.safetensors",
595
+ "visual.blocks.2.norm2.weight": "model-00001-of-00004.safetensors",
596
+ "visual.blocks.20.attn.proj.bias": "model-00001-of-00004.safetensors",
597
+ "visual.blocks.20.attn.proj.weight": "model-00001-of-00004.safetensors",
598
+ "visual.blocks.20.attn.qkv.bias": "model-00001-of-00004.safetensors",
599
+ "visual.blocks.20.attn.qkv.weight": "model-00001-of-00004.safetensors",
600
+ "visual.blocks.20.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
601
+ "visual.blocks.20.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
602
+ "visual.blocks.20.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
603
+ "visual.blocks.20.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
604
+ "visual.blocks.20.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
605
+ "visual.blocks.20.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
606
+ "visual.blocks.20.norm1.weight": "model-00001-of-00004.safetensors",
607
+ "visual.blocks.20.norm2.weight": "model-00001-of-00004.safetensors",
608
+ "visual.blocks.21.attn.proj.bias": "model-00001-of-00004.safetensors",
609
+ "visual.blocks.21.attn.proj.weight": "model-00001-of-00004.safetensors",
610
+ "visual.blocks.21.attn.qkv.bias": "model-00001-of-00004.safetensors",
611
+ "visual.blocks.21.attn.qkv.weight": "model-00001-of-00004.safetensors",
612
+ "visual.blocks.21.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
613
+ "visual.blocks.21.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
614
+ "visual.blocks.21.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
615
+ "visual.blocks.21.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
616
+ "visual.blocks.21.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
617
+ "visual.blocks.21.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
618
+ "visual.blocks.21.norm1.weight": "model-00001-of-00004.safetensors",
619
+ "visual.blocks.21.norm2.weight": "model-00001-of-00004.safetensors",
620
+ "visual.blocks.22.attn.proj.bias": "model-00001-of-00004.safetensors",
621
+ "visual.blocks.22.attn.proj.weight": "model-00001-of-00004.safetensors",
622
+ "visual.blocks.22.attn.qkv.bias": "model-00001-of-00004.safetensors",
623
+ "visual.blocks.22.attn.qkv.weight": "model-00001-of-00004.safetensors",
624
+ "visual.blocks.22.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
625
+ "visual.blocks.22.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
626
+ "visual.blocks.22.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
627
+ "visual.blocks.22.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
628
+ "visual.blocks.22.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
629
+ "visual.blocks.22.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
630
+ "visual.blocks.22.norm1.weight": "model-00001-of-00004.safetensors",
631
+ "visual.blocks.22.norm2.weight": "model-00001-of-00004.safetensors",
632
+ "visual.blocks.23.attn.proj.bias": "model-00001-of-00004.safetensors",
633
+ "visual.blocks.23.attn.proj.weight": "model-00001-of-00004.safetensors",
634
+ "visual.blocks.23.attn.qkv.bias": "model-00001-of-00004.safetensors",
635
+ "visual.blocks.23.attn.qkv.weight": "model-00001-of-00004.safetensors",
636
+ "visual.blocks.23.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
637
+ "visual.blocks.23.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
638
+ "visual.blocks.23.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
639
+ "visual.blocks.23.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
640
+ "visual.blocks.23.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
641
+ "visual.blocks.23.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
642
+ "visual.blocks.23.norm1.weight": "model-00001-of-00004.safetensors",
643
+ "visual.blocks.23.norm2.weight": "model-00001-of-00004.safetensors",
644
+ "visual.blocks.24.attn.proj.bias": "model-00001-of-00004.safetensors",
645
+ "visual.blocks.24.attn.proj.weight": "model-00001-of-00004.safetensors",
646
+ "visual.blocks.24.attn.qkv.bias": "model-00001-of-00004.safetensors",
647
+ "visual.blocks.24.attn.qkv.weight": "model-00001-of-00004.safetensors",
648
+ "visual.blocks.24.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
649
+ "visual.blocks.24.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
650
+ "visual.blocks.24.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
651
+ "visual.blocks.24.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
652
+ "visual.blocks.24.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
653
+ "visual.blocks.24.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
654
+ "visual.blocks.24.norm1.weight": "model-00001-of-00004.safetensors",
655
+ "visual.blocks.24.norm2.weight": "model-00001-of-00004.safetensors",
656
+ "visual.blocks.25.attn.proj.bias": "model-00001-of-00004.safetensors",
657
+ "visual.blocks.25.attn.proj.weight": "model-00001-of-00004.safetensors",
658
+ "visual.blocks.25.attn.qkv.bias": "model-00001-of-00004.safetensors",
659
+ "visual.blocks.25.attn.qkv.weight": "model-00001-of-00004.safetensors",
660
+ "visual.blocks.25.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
661
+ "visual.blocks.25.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
662
+ "visual.blocks.25.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
663
+ "visual.blocks.25.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
664
+ "visual.blocks.25.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
665
+ "visual.blocks.25.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
666
+ "visual.blocks.25.norm1.weight": "model-00001-of-00004.safetensors",
667
+ "visual.blocks.25.norm2.weight": "model-00001-of-00004.safetensors",
668
+ "visual.blocks.26.attn.proj.bias": "model-00001-of-00004.safetensors",
669
+ "visual.blocks.26.attn.proj.weight": "model-00001-of-00004.safetensors",
670
+ "visual.blocks.26.attn.qkv.bias": "model-00001-of-00004.safetensors",
671
+ "visual.blocks.26.attn.qkv.weight": "model-00001-of-00004.safetensors",
672
+ "visual.blocks.26.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
673
+ "visual.blocks.26.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
674
+ "visual.blocks.26.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
675
+ "visual.blocks.26.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
676
+ "visual.blocks.26.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
677
+ "visual.blocks.26.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
678
+ "visual.blocks.26.norm1.weight": "model-00001-of-00004.safetensors",
679
+ "visual.blocks.26.norm2.weight": "model-00001-of-00004.safetensors",
680
+ "visual.blocks.27.attn.proj.bias": "model-00001-of-00004.safetensors",
681
+ "visual.blocks.27.attn.proj.weight": "model-00001-of-00004.safetensors",
682
+ "visual.blocks.27.attn.qkv.bias": "model-00001-of-00004.safetensors",
683
+ "visual.blocks.27.attn.qkv.weight": "model-00001-of-00004.safetensors",
684
+ "visual.blocks.27.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
685
+ "visual.blocks.27.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
686
+ "visual.blocks.27.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
687
+ "visual.blocks.27.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
688
+ "visual.blocks.27.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
689
+ "visual.blocks.27.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
690
+ "visual.blocks.27.norm1.weight": "model-00001-of-00004.safetensors",
691
+ "visual.blocks.27.norm2.weight": "model-00001-of-00004.safetensors",
692
+ "visual.blocks.28.attn.proj.bias": "model-00001-of-00004.safetensors",
693
+ "visual.blocks.28.attn.proj.weight": "model-00001-of-00004.safetensors",
694
+ "visual.blocks.28.attn.qkv.bias": "model-00001-of-00004.safetensors",
695
+ "visual.blocks.28.attn.qkv.weight": "model-00001-of-00004.safetensors",
696
+ "visual.blocks.28.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
697
+ "visual.blocks.28.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
698
+ "visual.blocks.28.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
699
+ "visual.blocks.28.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
700
+ "visual.blocks.28.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
701
+ "visual.blocks.28.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
702
+ "visual.blocks.28.norm1.weight": "model-00001-of-00004.safetensors",
703
+ "visual.blocks.28.norm2.weight": "model-00001-of-00004.safetensors",
704
+ "visual.blocks.29.attn.proj.bias": "model-00001-of-00004.safetensors",
705
+ "visual.blocks.29.attn.proj.weight": "model-00001-of-00004.safetensors",
706
+ "visual.blocks.29.attn.qkv.bias": "model-00001-of-00004.safetensors",
707
+ "visual.blocks.29.attn.qkv.weight": "model-00001-of-00004.safetensors",
708
+ "visual.blocks.29.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
709
+ "visual.blocks.29.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
710
+ "visual.blocks.29.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
711
+ "visual.blocks.29.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
712
+ "visual.blocks.29.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
713
+ "visual.blocks.29.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
714
+ "visual.blocks.29.norm1.weight": "model-00001-of-00004.safetensors",
715
+ "visual.blocks.29.norm2.weight": "model-00001-of-00004.safetensors",
716
+ "visual.blocks.3.attn.proj.bias": "model-00001-of-00004.safetensors",
717
+ "visual.blocks.3.attn.proj.weight": "model-00001-of-00004.safetensors",
718
+ "visual.blocks.3.attn.qkv.bias": "model-00001-of-00004.safetensors",
719
+ "visual.blocks.3.attn.qkv.weight": "model-00001-of-00004.safetensors",
720
+ "visual.blocks.3.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
721
+ "visual.blocks.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
722
+ "visual.blocks.3.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
723
+ "visual.blocks.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
724
+ "visual.blocks.3.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
725
+ "visual.blocks.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
726
+ "visual.blocks.3.norm1.weight": "model-00001-of-00004.safetensors",
727
+ "visual.blocks.3.norm2.weight": "model-00001-of-00004.safetensors",
728
+ "visual.blocks.30.attn.proj.bias": "model-00001-of-00004.safetensors",
729
+ "visual.blocks.30.attn.proj.weight": "model-00001-of-00004.safetensors",
730
+ "visual.blocks.30.attn.qkv.bias": "model-00001-of-00004.safetensors",
731
+ "visual.blocks.30.attn.qkv.weight": "model-00001-of-00004.safetensors",
732
+ "visual.blocks.30.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
733
+ "visual.blocks.30.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
734
+ "visual.blocks.30.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
735
+ "visual.blocks.30.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
736
+ "visual.blocks.30.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
737
+ "visual.blocks.30.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
738
+ "visual.blocks.30.norm1.weight": "model-00001-of-00004.safetensors",
739
+ "visual.blocks.30.norm2.weight": "model-00001-of-00004.safetensors",
740
+ "visual.blocks.31.attn.proj.bias": "model-00001-of-00004.safetensors",
741
+ "visual.blocks.31.attn.proj.weight": "model-00001-of-00004.safetensors",
742
+ "visual.blocks.31.attn.qkv.bias": "model-00001-of-00004.safetensors",
743
+ "visual.blocks.31.attn.qkv.weight": "model-00001-of-00004.safetensors",
744
+ "visual.blocks.31.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
745
+ "visual.blocks.31.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
746
+ "visual.blocks.31.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
747
+ "visual.blocks.31.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
748
+ "visual.blocks.31.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
749
+ "visual.blocks.31.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
750
+ "visual.blocks.31.norm1.weight": "model-00001-of-00004.safetensors",
751
+ "visual.blocks.31.norm2.weight": "model-00001-of-00004.safetensors",
752
+ "visual.blocks.4.attn.proj.bias": "model-00001-of-00004.safetensors",
753
+ "visual.blocks.4.attn.proj.weight": "model-00001-of-00004.safetensors",
754
+ "visual.blocks.4.attn.qkv.bias": "model-00001-of-00004.safetensors",
755
+ "visual.blocks.4.attn.qkv.weight": "model-00001-of-00004.safetensors",
756
+ "visual.blocks.4.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
757
+ "visual.blocks.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
758
+ "visual.blocks.4.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
759
+ "visual.blocks.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
760
+ "visual.blocks.4.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
761
+ "visual.blocks.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
762
+ "visual.blocks.4.norm1.weight": "model-00001-of-00004.safetensors",
763
+ "visual.blocks.4.norm2.weight": "model-00001-of-00004.safetensors",
764
+ "visual.blocks.5.attn.proj.bias": "model-00001-of-00004.safetensors",
765
+ "visual.blocks.5.attn.proj.weight": "model-00001-of-00004.safetensors",
766
+ "visual.blocks.5.attn.qkv.bias": "model-00001-of-00004.safetensors",
767
+ "visual.blocks.5.attn.qkv.weight": "model-00001-of-00004.safetensors",
768
+ "visual.blocks.5.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
769
+ "visual.blocks.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
770
+ "visual.blocks.5.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
771
+ "visual.blocks.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
772
+ "visual.blocks.5.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
773
+ "visual.blocks.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
774
+ "visual.blocks.5.norm1.weight": "model-00001-of-00004.safetensors",
775
+ "visual.blocks.5.norm2.weight": "model-00001-of-00004.safetensors",
776
+ "visual.blocks.6.attn.proj.bias": "model-00001-of-00004.safetensors",
777
+ "visual.blocks.6.attn.proj.weight": "model-00001-of-00004.safetensors",
778
+ "visual.blocks.6.attn.qkv.bias": "model-00001-of-00004.safetensors",
779
+ "visual.blocks.6.attn.qkv.weight": "model-00001-of-00004.safetensors",
780
+ "visual.blocks.6.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
781
+ "visual.blocks.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
782
+ "visual.blocks.6.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
783
+ "visual.blocks.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
784
+ "visual.blocks.6.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
785
+ "visual.blocks.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
786
+ "visual.blocks.6.norm1.weight": "model-00001-of-00004.safetensors",
787
+ "visual.blocks.6.norm2.weight": "model-00001-of-00004.safetensors",
788
+ "visual.blocks.7.attn.proj.bias": "model-00001-of-00004.safetensors",
789
+ "visual.blocks.7.attn.proj.weight": "model-00001-of-00004.safetensors",
790
+ "visual.blocks.7.attn.qkv.bias": "model-00001-of-00004.safetensors",
791
+ "visual.blocks.7.attn.qkv.weight": "model-00001-of-00004.safetensors",
792
+ "visual.blocks.7.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
793
+ "visual.blocks.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
794
+ "visual.blocks.7.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
795
+ "visual.blocks.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
796
+ "visual.blocks.7.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
797
+ "visual.blocks.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
798
+ "visual.blocks.7.norm1.weight": "model-00001-of-00004.safetensors",
799
+ "visual.blocks.7.norm2.weight": "model-00001-of-00004.safetensors",
800
+ "visual.blocks.8.attn.proj.bias": "model-00001-of-00004.safetensors",
801
+ "visual.blocks.8.attn.proj.weight": "model-00001-of-00004.safetensors",
802
+ "visual.blocks.8.attn.qkv.bias": "model-00001-of-00004.safetensors",
803
+ "visual.blocks.8.attn.qkv.weight": "model-00001-of-00004.safetensors",
804
+ "visual.blocks.8.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
805
+ "visual.blocks.8.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
806
+ "visual.blocks.8.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
807
+ "visual.blocks.8.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
808
+ "visual.blocks.8.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
809
+ "visual.blocks.8.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
810
+ "visual.blocks.8.norm1.weight": "model-00001-of-00004.safetensors",
811
+ "visual.blocks.8.norm2.weight": "model-00001-of-00004.safetensors",
812
+ "visual.blocks.9.attn.proj.bias": "model-00001-of-00004.safetensors",
813
+ "visual.blocks.9.attn.proj.weight": "model-00001-of-00004.safetensors",
814
+ "visual.blocks.9.attn.qkv.bias": "model-00001-of-00004.safetensors",
815
+ "visual.blocks.9.attn.qkv.weight": "model-00001-of-00004.safetensors",
816
+ "visual.blocks.9.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
817
+ "visual.blocks.9.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
818
+ "visual.blocks.9.mlp.gate_proj.bias": "model-00001-of-00004.safetensors",
819
+ "visual.blocks.9.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
820
+ "visual.blocks.9.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
821
+ "visual.blocks.9.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
822
+ "visual.blocks.9.norm1.weight": "model-00001-of-00004.safetensors",
823
+ "visual.blocks.9.norm2.weight": "model-00001-of-00004.safetensors",
824
+ "visual.merger.ln_q.weight": "model-00001-of-00004.safetensors",
825
+ "visual.merger.mlp.0.bias": "model-00001-of-00004.safetensors",
826
+ "visual.merger.mlp.0.weight": "model-00001-of-00004.safetensors",
827
+ "visual.merger.mlp.2.bias": "model-00001-of-00004.safetensors",
828
+ "visual.merger.mlp.2.weight": "model-00001-of-00004.safetensors",
829
+ "visual.patch_embed.proj.weight": "model-00001-of-00004.safetensors"
830
+ }
831
+ }
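
The weight map above is the tail of the shard index for the mllm subfolder: for every parameter of the Qwen2.5-VL backbone (here the vision tower, merger and patch embedding), it records which of the four safetensors shards holds that tensor. A minimal sketch of reading the map, assuming a local checkout of the repo and the standard shard-index filename model.safetensors.index.json:

import json

# Hypothetical local path; the index maps parameter names to shard files.
with open("mllm/model.safetensors.index.json") as f:
    index = json.load(f)

# e.g. the vision patch embedding lives in the first shard
print(index["weight_map"]["visual.patch_embed.proj.weight"])
# -> model-00001-of-00004.safetensors
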
model_index.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "_class_name": "OmniGen2Pipeline",
3
+ "_diffusers_version": "0.33.1",
4
+ "mllm": [
5
+ "transformers",
6
+ "Qwen2_5_VLForConditionalGeneration"
7
+ ],
8
+ "processor": [
9
+ "transformers",
10
+ "Qwen2_5_VLProcessor"
11
+ ],
12
+ "scheduler": [
13
+ "scheduling_flow_match_euler_discrete",
14
+ "FlowMatchEulerDiscreteScheduler"
15
+ ],
16
+ "transformer": [
17
+ "transformer_omnigen2",
18
+ "OmniGen2Transformer2DModel"
19
+ ],
20
+ "vae": [
21
+ "diffusers",
22
+ "AutoencoderKL"
23
+ ]
24
+ }
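
model_index.json is what Diffusers reads to assemble the pipeline: each entry names a component subfolder and the (library, class) pair that implements it. The Qwen2.5-VL MLLM and processor come from transformers, the VAE from diffusers, and the scheduler and transformer from the custom modules uploaded in this repo; the OmniGen2Pipeline class itself ships with the OmniGen2 project code rather than with diffusers. A small sketch that only inspects this mapping, assuming a local copy of the file:

import json

# Hypothetical local path to the uploaded file.
with open("model_index.json") as f:
    model_index = json.load(f)

for name, spec in model_index.items():
    if name.startswith("_"):
        continue  # skip _class_name / _diffusers_version metadata
    library, cls = spec
    print(f"{name}: {cls} (from {library})")
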
processor/added_tokens.json ADDED
@@ -0,0 +1,28 @@
1
+ {
2
+ "</tool_call>": 151658,
3
+ "<tool_call>": 151657,
4
+ "<|box_end|>": 151649,
5
+ "<|box_start|>": 151648,
6
+ "<|endofimg|>": 151666,
7
+ "<|endofmeta|>": 151668,
8
+ "<|endoftext|>": 151643,
9
+ "<|file_sep|>": 151664,
10
+ "<|fim_middle|>": 151660,
11
+ "<|fim_pad|>": 151662,
12
+ "<|fim_prefix|>": 151659,
13
+ "<|fim_suffix|>": 151661,
14
+ "<|im_end|>": 151645,
15
+ "<|im_start|>": 151644,
16
+ "<|image_pad|>": 151655,
17
+ "<|img|>": 151665,
18
+ "<|meta|>": 151667,
19
+ "<|object_ref_end|>": 151647,
20
+ "<|object_ref_start|>": 151646,
21
+ "<|quad_end|>": 151651,
22
+ "<|quad_start|>": 151650,
23
+ "<|repo_name|>": 151663,
24
+ "<|video_pad|>": 151656,
25
+ "<|vision_end|>": 151653,
26
+ "<|vision_pad|>": 151654,
27
+ "<|vision_start|>": 151652
28
+ }
processor/chat_template.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
3
+ }
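
The processor's chat template wraps each image slot in <|vision_start|><|image_pad|><|vision_end|>, each turn in <|im_start|>…<|im_end|>, and inserts a default system prompt when none is given. A hedged sketch of rendering a prompt with it, assuming the repo id BAAI/OmniGen2 and that AutoProcessor is pointed at the processor subfolder:

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("BAAI/OmniGen2", subfolder="processor")

messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "Describe this picture."},
    ]},
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
print(prompt)  # contains <|vision_start|><|image_pad|><|vision_end|> for the image slot
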
processor/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
processor/preprocessor_config.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "do_convert_rgb": true,
3
+ "do_normalize": true,
4
+ "do_rescale": true,
5
+ "do_resize": true,
6
+ "image_mean": [
7
+ 0.48145466,
8
+ 0.4578275,
9
+ 0.40821073
10
+ ],
11
+ "image_processor_type": "Qwen2VLImageProcessor",
12
+ "image_std": [
13
+ 0.26862954,
14
+ 0.26130258,
15
+ 0.27577711
16
+ ],
17
+ "max_pixels": 12845056,
18
+ "merge_size": 2,
19
+ "min_pixels": 3136,
20
+ "patch_size": 14,
21
+ "processor_class": "Qwen2_5_VLProcessor",
22
+ "resample": 3,
23
+ "rescale_factor": 0.00392156862745098,
24
+ "size": {
25
+ "longest_edge": 12845056,
26
+ "shortest_edge": 3136
27
+ },
28
+ "temporal_patch_size": 2
29
+ }
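
With patch_size 14, merge_size 2 and the 3136 / 12845056 pixel bounds, the image processor resizes inputs so both sides are multiples of 28 and the total area stays between 56×56 and 3584×3584 pixels; each 28×28 region then becomes one visual token after the 2×2 merge. A rough sketch of the resulting token count (a simplification of the actual smart-resize logic, shown only to illustrate these config values):

import math

def approx_vision_tokens(height, width, patch=14, merge=2,
                         min_pixels=3136, max_pixels=12845056):
    # Scale the image so its area fits the configured pixel budget,
    # then snap both sides to multiples of patch * merge (28 px).
    scale = 1.0
    if height * width > max_pixels:
        scale = math.sqrt(max_pixels / (height * width))
    elif height * width < min_pixels:
        scale = math.sqrt(min_pixels / (height * width))
    step = patch * merge
    h = max(step, round(height * scale / step) * step)
    w = max(step, round(width * scale / step) * step)
    return (h // patch) * (w // patch) // (merge * merge)

print(approx_vision_tokens(1024, 1024))  # ~1369 tokens for a 1024x1024 image
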
processor/special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
processor/tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:69f12bcc978e3d112e092478b218f0161ef3d4bec08792866c99d29830772f08
3
+ size 11422644
processor/tokenizer_config.json ADDED
@@ -0,0 +1,241 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<|img|>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": true
188
+ },
189
+ "151666": {
190
+ "content": "<|endofimg|>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": true
196
+ },
197
+ "151667": {
198
+ "content": "<|meta|>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": true
204
+ },
205
+ "151668": {
206
+ "content": "<|endofmeta|>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": true
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|vision_start|>",
224
+ "<|vision_end|>",
225
+ "<|vision_pad|>",
226
+ "<|image_pad|>",
227
+ "<|video_pad|>"
228
+ ],
229
+ "bos_token": null,
230
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
231
+ "clean_up_tokenization_spaces": false,
232
+ "eos_token": "<|im_end|>",
233
+ "errors": "replace",
234
+ "extra_special_tokens": {},
235
+ "model_max_length": 131072,
236
+ "pad_token": "<|endoftext|>",
237
+ "processor_class": "Qwen2_5_VLProcessor",
238
+ "split_special_tokens": false,
239
+ "tokenizer_class": "Qwen2Tokenizer",
240
+ "unk_token": null
241
+ }
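
Besides the standard Qwen2.5 special tokens, the tokenizer registers four OmniGen2-specific markers, <|img|>, <|endofimg|>, <|meta|> and <|endofmeta|> (ids 151665–151668), presumably used to delimit image and metadata segments in the prompt. A quick check, assuming the repo id BAAI/OmniGen2 and the processor subfolder:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/OmniGen2", subfolder="processor")
for token in ["<|img|>", "<|endofimg|>", "<|meta|>", "<|endofmeta|>"]:
    print(token, tokenizer.convert_tokens_to_ids(token))
# expected ids: 151665, 151666, 151667, 151668
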
processor/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.33.1",
4
+ "dynamic_time_shift": true,
5
+ "num_train_timesteps": 1000
6
+ }
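
The only non-default option here is dynamic_time_shift: when set_timesteps is called with num_tokens (the number of latent positions), the linear schedule is warped by m = sqrt(num_tokens) / 40, so m = 1 for a 320×320 input and m = 3.2 for 1024×1024, per the comment in the scheduler code below. A tiny numeric illustration:

import numpy as np

num_tokens = 128 * 128                # 1024x1024 image -> 128x128 latent grid
m = np.sqrt(num_tokens) / 40          # 3.2
t = np.linspace(0, 1, 6, dtype=np.float32)[:-1]
print(t / (m - m * t + t))            # the shifted timesteps actually used
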
scheduler/scheduling_flow_match_euler_discrete.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
25
+
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ @dataclass
31
+ class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
32
+ """
33
+ Output class for the scheduler's `step` function output.
34
+
35
+ Args:
36
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
37
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
38
+ denoising loop.
39
+ """
40
+
41
+ prev_sample: torch.FloatTensor
42
+
43
+
44
+ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
45
+ """
46
+ Euler scheduler.
47
+
48
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
49
+ methods the library implements for all schedulers such as loading and saving.
50
+
51
+ Args:
52
+ num_train_timesteps (`int`, defaults to 1000):
53
+ The number of diffusion steps to train the model.
54
+ timestep_spacing (`str`, defaults to `"linspace"`):
55
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
56
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
57
+ shift (`float`, defaults to 1.0):
58
+ The shift value for the timestep schedule.
59
+ """
60
+
61
+ _compatibles = []
62
+ order = 1
63
+
64
+ @register_to_config
65
+ def __init__(
66
+ self,
67
+ num_train_timesteps: int = 1000,
68
+ dynamic_time_shift: bool = False
69
+ ):
70
+ timesteps = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float32)[:-1]
71
+
72
+ self.timesteps = timesteps
73
+
74
+ self._step_index = None
75
+ self._begin_index = None
76
+
77
+ @property
78
+ def step_index(self):
79
+ """
80
+ The index counter for current timestep. It will increase 1 after each scheduler step.
81
+ """
82
+ return self._step_index
83
+
84
+ @property
85
+ def begin_index(self):
86
+ """
87
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
88
+ """
89
+ return self._begin_index
90
+
91
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
92
+ def set_begin_index(self, begin_index: int = 0):
93
+ """
94
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
95
+
96
+ Args:
97
+ begin_index (`int`):
98
+ The begin index for the scheduler.
99
+ """
100
+ self._begin_index = begin_index
101
+
102
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
103
+ if schedule_timesteps is None:
104
+ schedule_timesteps = self._timesteps
105
+
106
+ indices = (schedule_timesteps == timestep).nonzero()
107
+
108
+ # The sigma index that is taken for the **very** first `step`
109
+ # is always the second index (or the last index if there is only 1)
110
+ # This way we can ensure we don't accidentally skip a sigma in
111
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
112
+ pos = 1 if len(indices) > 1 else 0
113
+
114
+ return indices[pos].item()
115
+
116
+ # def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
117
+ # return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
118
+
119
+ def set_timesteps(
120
+ self,
121
+ num_inference_steps: int = None,
122
+ device: Union[str, torch.device] = None,
123
+ timesteps: Optional[List[float]] = None,
124
+ num_tokens: Optional[int] = None
125
+ ):
126
+ """
127
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
128
+
129
+ Args:
130
+ num_inference_steps (`int`):
131
+ The number of diffusion steps used when generating samples with a pre-trained model.
132
+ device (`str` or `torch.device`, *optional*):
133
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
134
+ """
135
+
136
+ if timesteps is None:
137
+ self.num_inference_steps = num_inference_steps
138
+ timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
139
+ if self.config.dynamic_time_shift and num_tokens is not None:
140
+ m = np.sqrt(num_tokens) / 40 # when input resolution is 320 * 320, m = 1, when input resolution is 1024 * 1024, m = 3.2
141
+ timesteps = timesteps / (m - m * timesteps + timesteps)
142
+
143
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
144
+ _timesteps = torch.cat([timesteps, torch.ones(1, device=timesteps.device)])
145
+
146
+ self.timesteps = timesteps
147
+ self._timesteps = _timesteps
148
+ self._step_index = None
149
+ self._begin_index = None
150
+
151
+ def _init_step_index(self, timestep):
152
+ if self.begin_index is None:
153
+ if isinstance(timestep, torch.Tensor):
154
+ timestep = timestep.to(self.timesteps.device)
155
+ self._step_index = self.index_for_timestep(timestep)
156
+ else:
157
+ self._step_index = self._begin_index
158
+
159
+ def step(
160
+ self,
161
+ model_output: torch.FloatTensor,
162
+ timestep: Union[float, torch.FloatTensor],
163
+ sample: torch.FloatTensor,
164
+ generator: Optional[torch.Generator] = None,
165
+ return_dict: bool = True,
166
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
167
+ """
168
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
169
+ process from the learned model outputs (most often the predicted noise).
170
+
171
+ Args:
172
+ model_output (`torch.FloatTensor`):
173
+ The direct output from learned diffusion model.
174
+ timestep (`float`):
175
+ The current discrete timestep in the diffusion chain.
176
+ sample (`torch.FloatTensor`):
177
+ A current instance of a sample created by the diffusion process.
178
+ s_churn (`float`):
179
+ s_tmin (`float`):
180
+ s_tmax (`float`):
181
+ s_noise (`float`, defaults to 1.0):
182
+ Scaling factor for noise added to the sample.
183
+ generator (`torch.Generator`, *optional*):
184
+ A random number generator.
185
+ return_dict (`bool`):
186
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
187
+ tuple.
188
+
189
+ Returns:
190
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
191
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
192
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
193
+ """
194
+
195
+ if (
196
+ isinstance(timestep, int)
197
+ or isinstance(timestep, torch.IntTensor)
198
+ or isinstance(timestep, torch.LongTensor)
199
+ ):
200
+ raise ValueError(
201
+ (
202
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
203
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
204
+ " one of the `scheduler.timesteps` as a timestep."
205
+ ),
206
+ )
207
+
208
+ if self.step_index is None:
209
+ self._init_step_index(timestep)
210
+ # Upcast to avoid precision issues when computing prev_sample
211
+ sample = sample.to(torch.float32)
212
+ t = self._timesteps[self.step_index]
213
+ t_next = self._timesteps[self.step_index + 1]
214
+
215
+ prev_sample = sample + (t_next - t) * model_output
216
+
217
+ # Cast sample back to model compatible dtype
218
+ prev_sample = prev_sample.to(model_output.dtype)
219
+
220
+ # upon completion increase step index by one
221
+ self._step_index += 1
222
+
223
+ if not return_dict:
224
+ return (prev_sample,)
225
+
226
+ return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
227
+
228
+ def __len__(self):
229
+ return self.config.num_train_timesteps
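
A minimal standalone sketch of driving this scheduler: instantiate it, call set_timesteps with the step count and latent token count, then Euler-integrate with a dummy model output standing in for the transformer's velocity prediction (it assumes this file is importable from the working directory):

import torch
from scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, dynamic_time_shift=True)
scheduler.set_timesteps(num_inference_steps=50, num_tokens=128 * 128)  # 1024x1024 latents

latents = torch.randn(1, 16, 128, 128)            # in_channels=16, as in the transformer config
for t in scheduler.timesteps:
    velocity = torch.zeros_like(latents)          # placeholder for the model's prediction
    latents = scheduler.step(velocity, t, latents).prev_sample
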
transformer/config.json ADDED
@@ -0,0 +1,27 @@
1
+ {
2
+ "_class_name": "OmniGen2Transformer2DModel",
3
+ "_diffusers_version": "0.33.1",
4
+ "axes_dim_rope": [
5
+ 40,
6
+ 40,
7
+ 40
8
+ ],
9
+ "axes_lens": [
10
+ 1024,
11
+ 1664,
12
+ 1664
13
+ ],
14
+ "ffn_dim_multiplier": null,
15
+ "hidden_size": 2520,
16
+ "in_channels": 16,
17
+ "multiple_of": 256,
18
+ "norm_eps": 1e-05,
19
+ "num_attention_heads": 21,
20
+ "num_kv_heads": 7,
21
+ "num_layers": 32,
22
+ "num_refiner_layers": 2,
23
+ "out_channels": null,
24
+ "patch_size": 2,
25
+ "text_feat_dim": 2048,
26
+ "timestep_scale": 1000.0
27
+ }
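
A quick consistency check on these numbers: hidden_size 2520 split over 21 attention heads gives a 120-dimensional head, which is exactly the sum of the three rotary axes (40 + 40 + 40), and the 7 KV heads imply 3-way grouped-query attention. For instance:

import json

# Hypothetical local path to the uploaded config.
cfg = json.load(open("transformer/config.json"))

head_dim = cfg["hidden_size"] // cfg["num_attention_heads"]    # 2520 // 21 = 120
assert head_dim == sum(cfg["axes_dim_rope"])                   # 40 + 40 + 40 = 120
assert cfg["num_attention_heads"] // cfg["num_kv_heads"] == 3  # grouped-query attention
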
transformer/diffusion_pytorch_model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d167d318d9c34833d1522dffb3fdbc424fd4e482eda35faf6b24306317052ba
3
+ size 9995005752
transformer/diffusion_pytorch_model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:844bd3f41ce0879c9f968934e7c7cbbc58d57e46aae2661c23ac9f6478ebd9ba
3
+ size 5873701280
transformer/diffusion_pytorch_model.safetensors.index.json ADDED
@@ -0,0 +1,589 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 15868645600
4
+ },
5
+ "weight_map": {
6
+ "context_refiner.0.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
7
+ "context_refiner.0.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
8
+ "context_refiner.0.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
9
+ "context_refiner.0.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
10
+ "context_refiner.0.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
11
+ "context_refiner.0.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
12
+ "context_refiner.0.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
13
+ "context_refiner.0.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
14
+ "context_refiner.0.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
15
+ "context_refiner.0.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
16
+ "context_refiner.0.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
17
+ "context_refiner.0.norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
18
+ "context_refiner.0.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
19
+ "context_refiner.1.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
20
+ "context_refiner.1.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
21
+ "context_refiner.1.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
22
+ "context_refiner.1.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
23
+ "context_refiner.1.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
24
+ "context_refiner.1.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
25
+ "context_refiner.1.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
26
+ "context_refiner.1.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
27
+ "context_refiner.1.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
28
+ "context_refiner.1.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
29
+ "context_refiner.1.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
30
+ "context_refiner.1.norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
31
+ "context_refiner.1.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
32
+ "image_index_embedding": "diffusion_pytorch_model-00001-of-00002.safetensors",
33
+ "layers.0.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
34
+ "layers.0.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
35
+ "layers.0.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
36
+ "layers.0.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
37
+ "layers.0.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
38
+ "layers.0.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
39
+ "layers.0.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
40
+ "layers.0.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
41
+ "layers.0.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
42
+ "layers.0.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
43
+ "layers.0.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
44
+ "layers.0.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
45
+ "layers.0.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
46
+ "layers.0.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
47
+ "layers.0.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
48
+ "layers.1.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
49
+ "layers.1.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
50
+ "layers.1.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
51
+ "layers.1.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
52
+ "layers.1.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
53
+ "layers.1.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
54
+ "layers.1.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
55
+ "layers.1.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
56
+ "layers.1.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
57
+ "layers.1.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
58
+ "layers.1.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
59
+ "layers.1.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
60
+ "layers.1.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
61
+ "layers.1.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
62
+ "layers.1.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
63
+ "layers.10.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
64
+ "layers.10.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
65
+ "layers.10.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
66
+ "layers.10.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
67
+ "layers.10.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
68
+ "layers.10.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
69
+ "layers.10.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
70
+ "layers.10.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
71
+ "layers.10.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
72
+ "layers.10.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
73
+ "layers.10.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
74
+ "layers.10.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
75
+ "layers.10.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
76
+ "layers.10.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
77
+ "layers.10.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
78
+ "layers.11.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
79
+ "layers.11.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
80
+ "layers.11.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
81
+ "layers.11.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
82
+ "layers.11.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
83
+ "layers.11.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
84
+ "layers.11.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
85
+ "layers.11.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
86
+ "layers.11.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
87
+ "layers.11.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
88
+ "layers.11.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
89
+ "layers.11.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
90
+ "layers.11.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
91
+ "layers.11.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
92
+ "layers.11.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
93
+ "layers.12.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
94
+ "layers.12.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
95
+ "layers.12.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
96
+ "layers.12.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
97
+ "layers.12.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
98
+ "layers.12.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
99
+ "layers.12.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
100
+ "layers.12.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
101
+ "layers.12.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
102
+ "layers.12.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
103
+ "layers.12.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
104
+ "layers.12.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
105
+ "layers.12.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
106
+ "layers.12.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
107
+ "layers.12.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
108
+ "layers.13.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
109
+ "layers.13.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
110
+ "layers.13.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
111
+ "layers.13.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
112
+ "layers.13.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
113
+ "layers.13.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
114
+ "layers.13.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
115
+ "layers.13.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
116
+ "layers.13.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
117
+ "layers.13.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
118
+ "layers.13.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
119
+ "layers.13.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
120
+ "layers.13.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
121
+ "layers.13.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
122
+ "layers.13.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
123
+ "layers.14.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
124
+ "layers.14.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
125
+ "layers.14.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
126
+ "layers.14.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
127
+ "layers.14.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
128
+ "layers.14.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
129
+ "layers.14.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
130
+ "layers.14.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
131
+ "layers.14.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
132
+ "layers.14.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
133
+ "layers.14.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
134
+ "layers.14.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
135
+ "layers.14.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
136
+ "layers.14.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
137
+ "layers.14.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
138
+ "layers.15.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
139
+ "layers.15.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
140
+ "layers.15.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
141
+ "layers.15.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
142
+ "layers.15.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
143
+ "layers.15.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
144
+ "layers.15.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
145
+ "layers.15.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
146
+ "layers.15.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
147
+ "layers.15.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
148
+ "layers.15.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
149
+ "layers.15.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
150
+ "layers.15.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
151
+ "layers.15.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
152
+ "layers.15.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
153
+ "layers.16.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
154
+ "layers.16.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
155
+ "layers.16.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
156
+ "layers.16.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
157
+ "layers.16.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
158
+ "layers.16.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
159
+ "layers.16.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
160
+ "layers.16.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
161
+ "layers.16.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
162
+ "layers.16.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
163
+ "layers.16.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
164
+ "layers.16.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
165
+ "layers.16.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
166
+ "layers.16.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
167
+ "layers.16.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
168
+ "layers.17.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
169
+ "layers.17.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
170
+ "layers.17.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
171
+ "layers.17.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
172
+ "layers.17.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
173
+ "layers.17.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
174
+ "layers.17.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
175
+ "layers.17.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
176
+ "layers.17.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
177
+ "layers.17.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
178
+ "layers.17.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
179
+ "layers.17.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
180
+ "layers.17.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
181
+ "layers.17.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
182
+ "layers.17.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
183
+ "layers.18.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
184
+ "layers.18.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
185
+ "layers.18.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
186
+ "layers.18.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
187
+ "layers.18.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
188
+ "layers.18.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
189
+ "layers.18.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
190
+ "layers.18.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
191
+ "layers.18.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
192
+ "layers.18.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
193
+ "layers.18.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
194
+ "layers.18.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
195
+ "layers.18.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
196
+ "layers.18.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
197
+ "layers.18.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
198
+ "layers.19.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
199
+ "layers.19.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
200
+ "layers.19.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
201
+ "layers.19.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
202
+ "layers.19.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
203
+ "layers.19.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
204
+ "layers.19.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
205
+ "layers.19.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
206
+ "layers.19.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
207
+ "layers.19.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
208
+ "layers.19.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
209
+ "layers.19.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
210
+ "layers.19.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
211
+ "layers.19.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
212
+ "layers.19.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
213
+ "layers.2.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
214
+ "layers.2.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
215
+ "layers.2.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
216
+ "layers.2.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
217
+ "layers.2.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
218
+ "layers.2.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
219
+ "layers.2.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
220
+ "layers.2.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
221
+ "layers.2.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
222
+ "layers.2.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
223
+ "layers.2.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
224
+ "layers.2.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
225
+ "layers.2.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
226
+ "layers.2.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
227
+ "layers.2.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
228
+ "layers.20.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
229
+ "layers.20.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
230
+ "layers.20.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
231
+ "layers.20.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
232
+ "layers.20.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
233
+ "layers.20.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
234
+ "layers.20.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
235
+ "layers.20.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
236
+ "layers.20.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
237
+ "layers.20.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
238
+ "layers.20.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
239
+ "layers.20.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
240
+ "layers.20.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
241
+ "layers.20.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
242
+ "layers.20.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
243
+ "layers.21.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
244
+ "layers.21.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
245
+ "layers.21.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
246
+ "layers.21.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
247
+ "layers.21.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
248
+ "layers.21.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
249
+ "layers.21.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
250
+ "layers.21.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
251
+ "layers.21.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
252
+ "layers.21.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
253
+ "layers.21.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
254
+ "layers.21.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
255
+ "layers.21.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
256
+ "layers.21.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
257
+ "layers.21.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
258
+ "layers.22.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
259
+ "layers.22.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
260
+ "layers.22.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
261
+ "layers.22.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
262
+ "layers.22.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
263
+ "layers.22.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
264
+ "layers.22.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
265
+ "layers.22.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
266
+ "layers.22.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
267
+ "layers.22.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
268
+ "layers.22.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
269
+ "layers.22.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
270
+ "layers.22.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
271
+ "layers.22.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
272
+ "layers.22.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
273
+ "layers.23.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
274
+ "layers.23.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
275
+ "layers.23.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
276
+ "layers.23.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
277
+ "layers.23.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
278
+ "layers.23.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
279
+ "layers.23.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
280
+ "layers.23.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
281
+ "layers.23.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
282
+ "layers.23.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
283
+ "layers.23.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
284
+ "layers.23.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
285
+ "layers.23.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
286
+ "layers.23.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
287
+ "layers.23.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
288
+ "layers.24.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
289
+ "layers.24.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
290
+ "layers.24.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
291
+ "layers.24.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
292
+ "layers.24.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
293
+ "layers.24.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
294
+ "layers.24.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
295
+ "layers.24.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
296
+ "layers.24.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
297
+ "layers.24.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
298
+ "layers.24.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
299
+ "layers.24.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
300
+ "layers.24.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
301
+ "layers.24.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
302
+ "layers.24.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
303
+ "layers.25.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
304
+ "layers.25.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
305
+ "layers.25.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
306
+ "layers.25.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
307
+ "layers.25.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
308
+ "layers.25.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
309
+ "layers.25.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
310
+ "layers.25.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
311
+ "layers.25.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
312
+ "layers.25.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
313
+ "layers.25.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
314
+ "layers.25.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
315
+ "layers.25.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
316
+ "layers.25.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
317
+ "layers.25.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
318
+ "layers.26.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
319
+ "layers.26.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
320
+ "layers.26.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
321
+ "layers.26.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
322
+ "layers.26.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
323
+ "layers.26.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
324
+ "layers.26.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
325
+ "layers.26.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
326
+ "layers.26.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
327
+ "layers.26.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
328
+ "layers.26.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
329
+ "layers.26.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
330
+ "layers.26.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
331
+ "layers.26.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
332
+ "layers.26.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
333
+ "layers.27.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
334
+ "layers.27.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
335
+ "layers.27.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
336
+ "layers.27.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
337
+ "layers.27.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
338
+ "layers.27.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
339
+ "layers.27.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
340
+ "layers.27.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
341
+ "layers.27.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
342
+ "layers.27.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
343
+ "layers.27.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
344
+ "layers.27.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
345
+ "layers.27.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
346
+ "layers.27.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
347
+ "layers.27.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
348
+ "layers.28.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
349
+ "layers.28.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
350
+ "layers.28.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
351
+ "layers.28.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
352
+ "layers.28.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
353
+ "layers.28.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
354
+ "layers.28.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
355
+ "layers.28.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
356
+ "layers.28.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
357
+ "layers.28.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
358
+ "layers.28.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
359
+ "layers.28.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
360
+ "layers.28.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
361
+ "layers.28.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
362
+ "layers.28.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
363
+ "layers.29.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
364
+ "layers.29.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
365
+ "layers.29.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
366
+ "layers.29.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
367
+ "layers.29.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
368
+ "layers.29.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
369
+ "layers.29.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
370
+ "layers.29.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
371
+ "layers.29.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
372
+ "layers.29.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
373
+ "layers.29.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
374
+ "layers.29.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
375
+ "layers.29.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
376
+ "layers.29.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
377
+ "layers.29.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
378
+ "layers.3.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
379
+ "layers.3.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
380
+ "layers.3.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
381
+ "layers.3.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
382
+ "layers.3.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
383
+ "layers.3.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
384
+ "layers.3.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
385
+ "layers.3.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
386
+ "layers.3.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
387
+ "layers.3.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
388
+ "layers.3.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
389
+ "layers.3.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
390
+ "layers.3.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
391
+ "layers.3.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
392
+ "layers.3.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
393
+ "layers.30.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
394
+ "layers.30.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
395
+ "layers.30.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
396
+ "layers.30.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
397
+ "layers.30.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
398
+ "layers.30.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
399
+ "layers.30.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
400
+ "layers.30.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
401
+ "layers.30.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
402
+ "layers.30.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
403
+ "layers.30.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
404
+ "layers.30.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
405
+ "layers.30.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
406
+ "layers.30.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
407
+ "layers.30.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
408
+ "layers.31.attn.norm_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
409
+ "layers.31.attn.norm_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
410
+ "layers.31.attn.to_k.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
411
+ "layers.31.attn.to_out.0.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
412
+ "layers.31.attn.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
413
+ "layers.31.attn.to_v.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
414
+ "layers.31.feed_forward.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
415
+ "layers.31.feed_forward.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
416
+ "layers.31.feed_forward.linear_3.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
417
+ "layers.31.ffn_norm1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
418
+ "layers.31.ffn_norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
419
+ "layers.31.norm1.linear.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
420
+ "layers.31.norm1.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
421
+ "layers.31.norm1.norm.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
422
+ "layers.31.norm2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
423
+ "layers.4.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
424
+ "layers.4.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
425
+ "layers.4.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
426
+ "layers.4.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
427
+ "layers.4.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
428
+ "layers.4.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
429
+ "layers.4.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
430
+ "layers.4.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
431
+ "layers.4.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
432
+ "layers.4.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
433
+ "layers.4.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
434
+ "layers.4.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
435
+ "layers.4.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
436
+ "layers.4.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
437
+ "layers.4.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
438
+ "layers.5.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
439
+ "layers.5.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
440
+ "layers.5.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
441
+ "layers.5.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
442
+ "layers.5.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
443
+ "layers.5.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
444
+ "layers.5.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
445
+ "layers.5.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
446
+ "layers.5.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
447
+ "layers.5.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
448
+ "layers.5.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
449
+ "layers.5.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
450
+ "layers.5.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
451
+ "layers.5.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
452
+ "layers.5.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
453
+ "layers.6.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
454
+ "layers.6.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
455
+ "layers.6.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
456
+ "layers.6.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
457
+ "layers.6.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
458
+ "layers.6.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
459
+ "layers.6.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
460
+ "layers.6.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
461
+ "layers.6.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
462
+ "layers.6.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
463
+ "layers.6.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
464
+ "layers.6.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
465
+ "layers.6.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
466
+ "layers.6.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
467
+ "layers.6.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
468
+ "layers.7.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
469
+ "layers.7.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
470
+ "layers.7.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
471
+ "layers.7.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
472
+ "layers.7.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
473
+ "layers.7.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
474
+ "layers.7.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
475
+ "layers.7.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
476
+ "layers.7.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
477
+ "layers.7.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
478
+ "layers.7.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
479
+ "layers.7.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
480
+ "layers.7.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
481
+ "layers.7.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
482
+ "layers.7.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
483
+ "layers.8.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
484
+ "layers.8.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
485
+ "layers.8.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
486
+ "layers.8.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
487
+ "layers.8.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
488
+ "layers.8.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
489
+ "layers.8.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
490
+ "layers.8.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
491
+ "layers.8.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
492
+ "layers.8.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
493
+ "layers.8.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
494
+ "layers.8.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
495
+ "layers.8.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
496
+ "layers.8.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
497
+ "layers.8.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
498
+ "layers.9.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
499
+ "layers.9.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
500
+ "layers.9.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
501
+ "layers.9.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
502
+ "layers.9.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
503
+ "layers.9.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
504
+ "layers.9.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
505
+ "layers.9.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
506
+ "layers.9.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
507
+ "layers.9.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
508
+ "layers.9.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
509
+ "layers.9.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
510
+ "layers.9.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
511
+ "layers.9.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
512
+ "layers.9.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
513
+ "noise_refiner.0.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
514
+ "noise_refiner.0.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
515
+ "noise_refiner.0.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
516
+ "noise_refiner.0.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
517
+ "noise_refiner.0.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
518
+ "noise_refiner.0.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
519
+ "noise_refiner.0.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
520
+ "noise_refiner.0.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
521
+ "noise_refiner.0.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
522
+ "noise_refiner.0.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
523
+ "noise_refiner.0.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
524
+ "noise_refiner.0.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
525
+ "noise_refiner.0.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
526
+ "noise_refiner.0.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
527
+ "noise_refiner.0.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
528
+ "noise_refiner.1.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
529
+ "noise_refiner.1.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
530
+ "noise_refiner.1.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
531
+ "noise_refiner.1.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
532
+ "noise_refiner.1.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
533
+ "noise_refiner.1.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
534
+ "noise_refiner.1.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
535
+ "noise_refiner.1.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
536
+ "noise_refiner.1.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
537
+ "noise_refiner.1.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
538
+ "noise_refiner.1.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
539
+ "noise_refiner.1.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
540
+ "noise_refiner.1.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
541
+ "noise_refiner.1.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
542
+ "noise_refiner.1.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
543
+ "norm_out.linear_1.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
544
+ "norm_out.linear_1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
545
+ "norm_out.linear_2.bias": "diffusion_pytorch_model-00002-of-00002.safetensors",
546
+ "norm_out.linear_2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
547
+ "ref_image_patch_embedder.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
548
+ "ref_image_patch_embedder.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
549
+ "ref_image_refiner.0.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
550
+ "ref_image_refiner.0.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
551
+ "ref_image_refiner.0.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
552
+ "ref_image_refiner.0.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
553
+ "ref_image_refiner.0.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
554
+ "ref_image_refiner.0.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
555
+ "ref_image_refiner.0.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
556
+ "ref_image_refiner.0.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
557
+ "ref_image_refiner.0.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
558
+ "ref_image_refiner.0.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
559
+ "ref_image_refiner.0.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
560
+ "ref_image_refiner.0.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
561
+ "ref_image_refiner.0.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
562
+ "ref_image_refiner.0.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
563
+ "ref_image_refiner.0.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
564
+ "ref_image_refiner.1.attn.norm_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
565
+ "ref_image_refiner.1.attn.norm_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
566
+ "ref_image_refiner.1.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
567
+ "ref_image_refiner.1.attn.to_out.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
568
+ "ref_image_refiner.1.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
569
+ "ref_image_refiner.1.attn.to_v.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
570
+ "ref_image_refiner.1.feed_forward.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
571
+ "ref_image_refiner.1.feed_forward.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
572
+ "ref_image_refiner.1.feed_forward.linear_3.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
573
+ "ref_image_refiner.1.ffn_norm1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
574
+ "ref_image_refiner.1.ffn_norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
575
+ "ref_image_refiner.1.norm1.linear.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
576
+ "ref_image_refiner.1.norm1.linear.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
577
+ "ref_image_refiner.1.norm1.norm.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
578
+ "ref_image_refiner.1.norm2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
579
+ "time_caption_embed.caption_embedder.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
580
+ "time_caption_embed.caption_embedder.1.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
581
+ "time_caption_embed.caption_embedder.1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
582
+ "time_caption_embed.timestep_embedder.linear_1.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
583
+ "time_caption_embed.timestep_embedder.linear_1.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
584
+ "time_caption_embed.timestep_embedder.linear_2.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
585
+ "time_caption_embed.timestep_embedder.linear_2.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
586
+ "x_embedder.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
587
+ "x_embedder.weight": "diffusion_pytorch_model-00001-of-00002.safetensors"
588
+ }
589
+ }
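The entries above are the tail of a standard sharded-checkpoint index: each parameter name points to the shard file that stores it. As a minimal editorial sketch (not part of this commit), and assuming the index is saved under the usual diffusers name diffusion_pytorch_model.safetensors.index.json with the standard "weight_map" key, it can be inspected like this:

import json
from collections import defaultdict

# Hypothetical local path; the shard file names come from the entries above.
index_path = "transformer/diffusion_pytorch_model.safetensors.index.json"

with open(index_path) as f:
    index = json.load(f)

# Group parameter names by the shard that stores them.
shards = defaultdict(list)
for param_name, shard_file in index["weight_map"].items():
    shards[shard_file].append(param_name)

for shard_file, names in sorted(shards.items()):
    print(f"{shard_file}: {len(names)} tensors")

# Individual tensors can then be read lazily, e.g. with safetensors:
# from safetensors import safe_open
# with safe_open(f"transformer/{shard_file}", framework="pt", device="cpu") as sf:
#     tensor = sf.get_tensor(names[0])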
transformer/transformer_omnigen2.py ADDED
@@ -0,0 +1,2104 @@
1
+ import warnings
2
+ import itertools
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from einops import rearrange, repeat
11
+
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.loaders import PeftAdapterMixin
14
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
15
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
16
+ from diffusers.models.attention_processor import Attention
17
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
18
+ from diffusers.models.modeling_utils import ModelMixin
19
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
20
+ from diffusers.models.activations import get_activation
21
+ from diffusers.models.embeddings import Timesteps
22
+
23
+ import importlib.util
24
+ import sys
25
+
26
+ # The package importlib_metadata is in a different place, depending on the python version.
27
+ if sys.version_info < (3, 8):
28
+ import importlib_metadata
29
+ else:
30
+ import importlib.metadata as importlib_metadata
31
+
32
+ def _is_package_available(pkg_name: str):
33
+ pkg_exists = importlib.util.find_spec(pkg_name) is not None
34
+ pkg_version = "N/A"
35
+
36
+ if pkg_exists:
37
+ try:
38
+ pkg_version = importlib_metadata.version(pkg_name)
39
+ except (ImportError, importlib_metadata.PackageNotFoundError):
40
+ pkg_exists = False
41
+
42
+ return pkg_exists, pkg_version
43
+
44
+ _triton_available, _triton_version = _is_package_available("triton")
45
+ _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
46
+
47
+ def is_triton_available():
48
+ return _triton_available
49
+
50
+ def is_flash_attn_available():
51
+ return _flash_attn_available
52
+
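A quick usage sketch for the helper above (editorial, not part of the uploaded file): it reports both whether a package can be imported and, when it can, the installed version.

# e.g. probing for the optional Triton dependency
exists, version = _is_package_available("triton")
print(f"triton available: {exists}, version: {version}")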
53
+ if is_triton_available():
54
+ # from ...ops.triton.layer_norm import RMSNorm
55
+ import triton
56
+ import triton.language as tl
57
+
58
+
59
+ from typing import Callable
60
+
61
+
62
+ def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
63
+ def decorator(*args, **kwargs):
64
+ if cuda_amp_deprecated:
65
+ kwargs["device_type"] = "cuda"
66
+ return dec(*args, **kwargs)
67
+ return decorator
68
+
69
+
70
+ if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined]
71
+ deprecated = True
72
+ from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined]
73
+ else:
74
+ deprecated = False
75
+ from torch.cuda.amp import custom_fwd, custom_bwd
76
+
77
+ custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
78
+ custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
79
+
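The shim above only bridges the torch.cuda.amp to torch.amp rename; on recent PyTorch (roughly 2.4+, where torch.amp.custom_fwd takes a device_type keyword) the same effect can be written directly. An illustrative sketch, not part of the uploaded file:

import torch

class _Scale(torch.autograd.Function):
    # Toy autograd Function showing where custom_fwd/custom_bwd attach.
    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda")
    def forward(ctx, x, factor):
        ctx.factor = factor
        return x * factor

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad_out):
        return grad_out * ctx.factor, None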
80
+
81
+ def triton_autotune_configs():
82
+ # Return configs with a valid warp count for the current device
83
+ configs=[]
84
+ # The maximum threads per block is architecture-dependent in theory, but in practice it is 1024 on current GPUs
85
+ max_threads_per_block=1024
86
+ # Default to warp size 32 if not defined by device
87
+ warp_size=getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
88
+ # Autotune over warp counts that are powers of 2 and do not exceed the threads-per-block limit
89
+ warp_count=1
90
+ while warp_count*warp_size <= max_threads_per_block:
91
+ configs.append(triton.Config({}, num_warps=warp_count))
92
+ warp_count*=2
93
+ return configs
94
+
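On a typical NVIDIA GPU (warp size 32, 1024 threads per block) the loop above yields num_warps = 1, 2, 4, 8, 16, 32. A standalone sketch of the same sweep, with Triton not required (editorial example):

def warp_count_sweep(warp_size: int = 32, max_threads_per_block: int = 1024):
    # Powers of two whose total thread count stays within the block limit.
    counts, warp_count = [], 1
    while warp_count * warp_size <= max_threads_per_block:
        counts.append(warp_count)
        warp_count *= 2
    return counts

print(warp_count_sweep())  # [1, 2, 4, 8, 16, 32]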
95
+ @triton.autotune(
96
+ configs=triton_autotune_configs(),
97
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
98
+ )
99
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
100
+ # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
101
+ @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
102
+ @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
103
+ @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
104
+ @triton.jit
105
+ def _layer_norm_fwd_1pass_kernel(
106
+ X, # pointer to the input
107
+ Y, # pointer to the output
108
+ W, # pointer to the weights
109
+ B, # pointer to the biases
110
+ RESIDUAL, # pointer to the residual
111
+ X1,
112
+ W1,
113
+ B1,
114
+ Y1,
115
+ RESIDUAL_OUT, # pointer to the residual
116
+ ROWSCALE,
117
+ SEEDS, # Dropout seeds for each row
118
+ DROPOUT_MASK,
119
+ Mean, # pointer to the mean
120
+ Rstd, # pointer to the 1/std
121
+ stride_x_row, # how much to increase the pointer when moving by 1 row
122
+ stride_y_row,
123
+ stride_res_row,
124
+ stride_res_out_row,
125
+ stride_x1_row,
126
+ stride_y1_row,
127
+ M, # number of rows in X
128
+ N, # number of columns in X
129
+ eps, # epsilon to avoid division by zero
130
+ dropout_p, # Dropout probability
131
+ zero_centered_weight, # If true, add 1.0 to the weight
132
+ IS_RMS_NORM: tl.constexpr,
133
+ BLOCK_N: tl.constexpr,
134
+ HAS_RESIDUAL: tl.constexpr,
135
+ STORE_RESIDUAL_OUT: tl.constexpr,
136
+ HAS_BIAS: tl.constexpr,
137
+ HAS_DROPOUT: tl.constexpr,
138
+ STORE_DROPOUT_MASK: tl.constexpr,
139
+ HAS_ROWSCALE: tl.constexpr,
140
+ HAS_X1: tl.constexpr,
141
+ HAS_W1: tl.constexpr,
142
+ HAS_B1: tl.constexpr,
143
+ ):
144
+ # Map the program id to the row of X and Y it should compute.
145
+ row = tl.program_id(0)
146
+ X += row * stride_x_row
147
+ Y += row * stride_y_row
148
+ if HAS_RESIDUAL:
149
+ RESIDUAL += row * stride_res_row
150
+ if STORE_RESIDUAL_OUT:
151
+ RESIDUAL_OUT += row * stride_res_out_row
152
+ if HAS_X1:
153
+ X1 += row * stride_x1_row
154
+ if HAS_W1:
155
+ Y1 += row * stride_y1_row
156
+ # Compute mean and variance
157
+ cols = tl.arange(0, BLOCK_N)
158
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
159
+ if HAS_ROWSCALE:
160
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
161
+ x *= rowscale
162
+ if HAS_DROPOUT:
163
+ # Compute dropout mask
164
+ # 7 rounds is good enough, and reduces register pressure
165
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
166
+ x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
167
+ if STORE_DROPOUT_MASK:
168
+ tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
169
+ if HAS_X1:
170
+ x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
171
+ if HAS_ROWSCALE:
172
+ rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
173
+ x1 *= rowscale
174
+ if HAS_DROPOUT:
175
+ # Compute dropout mask
176
+ # 7 rounds is good enough, and reduces register pressure
177
+ keep_mask = (
178
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
179
+ )
180
+ x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
181
+ if STORE_DROPOUT_MASK:
182
+ tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
183
+ x += x1
184
+ if HAS_RESIDUAL:
185
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
186
+ x += residual
187
+ if STORE_RESIDUAL_OUT:
188
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
189
+ if not IS_RMS_NORM:
190
+ mean = tl.sum(x, axis=0) / N
191
+ tl.store(Mean + row, mean)
192
+ xbar = tl.where(cols < N, x - mean, 0.0)
193
+ var = tl.sum(xbar * xbar, axis=0) / N
194
+ else:
195
+ xbar = tl.where(cols < N, x, 0.0)
196
+ var = tl.sum(xbar * xbar, axis=0) / N
197
+ rstd = 1 / tl.sqrt(var + eps)
198
+ tl.store(Rstd + row, rstd)
199
+ # Normalize and apply linear transformation
200
+ mask = cols < N
201
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
202
+ if zero_centered_weight:
203
+ w += 1.0
204
+ if HAS_BIAS:
205
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
206
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
207
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
208
+ # Write output
209
+ tl.store(Y + cols, y, mask=mask)
210
+ if HAS_W1:
211
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
212
+ if zero_centered_weight:
213
+ w1 += 1.0
214
+ if HAS_B1:
215
+ b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
216
+ y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
217
+ tl.store(Y1 + cols, y1, mask=mask)
218
+
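For readability, the math the fused kernel above computes in its plain case (no dropout, no rowscale, no second input X1 or second weight set) matches the following PyTorch reference. This is an editorial sketch, not part of the uploaded module:

import torch

def layer_norm_ref(x, weight, bias=None, residual=None, eps=1e-6,
                   zero_centered_weight=False, is_rms_norm=False):
    # x: (M, N); the kernel upcasts everything to float32 before normalizing.
    x = x.float()
    if residual is not None:
        x = x + residual.float()
    w = weight.float() + 1.0 if zero_centered_weight else weight.float()
    if is_rms_norm:
        x_hat = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    else:
        mean = x.mean(dim=-1, keepdim=True)
        var = (x - mean).pow(2).mean(dim=-1, keepdim=True)
        x_hat = (x - mean) * torch.rsqrt(var + eps)
    y = x_hat * w
    if bias is not None:
        y = y + bias.float()
    return y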
219
+
220
+ def _layer_norm_fwd(
221
+ x,
222
+ weight,
223
+ bias,
224
+ eps,
225
+ residual=None,
226
+ x1=None,
227
+ weight1=None,
228
+ bias1=None,
229
+ dropout_p=0.0,
230
+ rowscale=None,
231
+ out_dtype=None,
232
+ residual_dtype=None,
233
+ zero_centered_weight=False,
234
+ is_rms_norm=False,
235
+ return_dropout_mask=False,
236
+ out=None,
237
+ residual_out=None
238
+ ):
239
+ if residual is not None:
240
+ residual_dtype = residual.dtype
241
+ M, N = x.shape
242
+ assert x.stride(-1) == 1
243
+ if residual is not None:
244
+ assert residual.stride(-1) == 1
245
+ assert residual.shape == (M, N)
246
+ assert weight.shape == (N,)
247
+ assert weight.stride(-1) == 1
248
+ if bias is not None:
249
+ assert bias.stride(-1) == 1
250
+ assert bias.shape == (N,)
251
+ if x1 is not None:
252
+ assert x1.shape == x.shape
253
+ assert rowscale is None
254
+ assert x1.stride(-1) == 1
255
+ if weight1 is not None:
256
+ assert weight1.shape == (N,)
257
+ assert weight1.stride(-1) == 1
258
+ if bias1 is not None:
259
+ assert bias1.shape == (N,)
260
+ assert bias1.stride(-1) == 1
261
+ if rowscale is not None:
262
+ assert rowscale.is_contiguous()
263
+ assert rowscale.shape == (M,)
264
+ # allocate output
265
+ if out is None:
266
+ out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
267
+ else:
268
+ assert out.shape == x.shape
269
+ assert out.stride(-1) == 1
270
+ if weight1 is not None:
271
+ y1 = torch.empty_like(out)
272
+ assert y1.stride(-1) == 1
273
+ else:
274
+ y1 = None
275
+ if (
276
+ residual is not None
277
+ or (residual_dtype is not None and residual_dtype != x.dtype)
278
+ or dropout_p > 0.0
279
+ or rowscale is not None
280
+ or x1 is not None
281
+ ):
282
+ if residual_out is None:
283
+ residual_out = torch.empty(
284
+ M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
285
+ )
286
+ else:
287
+ assert residual_out.shape == x.shape
288
+ assert residual_out.stride(-1) == 1
289
+ else:
290
+ residual_out = None
291
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
292
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
293
+ if dropout_p > 0.0:
294
+ seeds = torch.randint(
295
+ 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
296
+ )
297
+ else:
298
+ seeds = None
299
+ if return_dropout_mask and dropout_p > 0.0:
300
+ dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
301
+ else:
302
+ dropout_mask = None
303
+ # Less than 64KB per feature: enqueue fused kernel
304
+ MAX_FUSED_SIZE = 65536 // x.element_size()
305
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
306
+ if N > BLOCK_N:
307
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
308
+ with torch.cuda.device(x.device.index):
309
+ _layer_norm_fwd_1pass_kernel[(M,)](
310
+ x,
311
+ out,
312
+ weight,
313
+ bias,
314
+ residual,
315
+ x1,
316
+ weight1,
317
+ bias1,
318
+ y1,
319
+ residual_out,
320
+ rowscale,
321
+ seeds,
322
+ dropout_mask,
323
+ mean,
324
+ rstd,
325
+ x.stride(0),
326
+ out.stride(0),
327
+ residual.stride(0) if residual is not None else 0,
328
+ residual_out.stride(0) if residual_out is not None else 0,
329
+ x1.stride(0) if x1 is not None else 0,
330
+ y1.stride(0) if y1 is not None else 0,
331
+ M,
332
+ N,
333
+ eps,
334
+ dropout_p,
335
+ zero_centered_weight,
336
+ is_rms_norm,
337
+ BLOCK_N,
338
+ residual is not None,
339
+ residual_out is not None,
340
+ bias is not None,
341
+ dropout_p > 0.0,
342
+ dropout_mask is not None,
343
+ rowscale is not None,
344
+ )
345
+ # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
346
+ if dropout_mask is not None and x1 is not None:
347
+ dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
348
+ else:
349
+ dropout_mask1 = None
350
+ return (
351
+ out,
352
+ y1,
353
+ mean,
354
+ rstd,
355
+ residual_out if residual_out is not None else x,
356
+ seeds,
357
+ dropout_mask,
358
+ dropout_mask1,
359
+ )
360
+
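A hedged usage sketch for the wrapper above in RMSNorm mode (editorial; requires a CUDA device with Triton installed, and relies only on the signature and return tuple visible here):

import torch

M, N = 4, 2048
x = torch.randn(M, N, device="cuda", dtype=torch.float16)
weight = torch.ones(N, device="cuda", dtype=torch.float16)

out, y1, mean, rstd, residual_out, seeds, mask, mask1 = _layer_norm_fwd(
    x, weight, bias=None, eps=1e-6, is_rms_norm=True
)
# In RMSNorm mode `mean` is None and `rstd` holds 1/sqrt(mean(x**2) + eps) per row.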
361
+ @triton.autotune(
362
+ configs=triton_autotune_configs(),
363
+ key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
364
+ )
365
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
366
+ # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
367
+ # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
368
+ @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
369
+ @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
370
+ @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
371
+ @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
372
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
373
+ @triton.jit
374
+ def _layer_norm_bwd_kernel(
375
+ X, # pointer to the input
376
+ W, # pointer to the weights
377
+ B, # pointer to the biases
378
+ Y, # pointer to the output to be recomputed
379
+ DY, # pointer to the output gradient
380
+ DX, # pointer to the input gradient
381
+ DW, # pointer to the partial sum of weights gradient
382
+ DB, # pointer to the partial sum of biases gradient
383
+ DRESIDUAL,
384
+ W1,
385
+ DY1,
386
+ DX1,
387
+ DW1,
388
+ DB1,
389
+ DRESIDUAL_IN,
390
+ ROWSCALE,
391
+ SEEDS,
392
+ Mean, # pointer to the mean
393
+ Rstd, # pointer to the 1/std
394
+ stride_x_row, # how much to increase the pointer when moving by 1 row
395
+ stride_y_row,
396
+ stride_dy_row,
397
+ stride_dx_row,
398
+ stride_dres_row,
399
+ stride_dy1_row,
400
+ stride_dx1_row,
401
+ stride_dres_in_row,
402
+ M, # number of rows in X
403
+ N, # number of columns in X
404
+ eps, # epsilon to avoid division by zero
405
+ dropout_p,
406
+ zero_centered_weight,
407
+ rows_per_program,
408
+ IS_RMS_NORM: tl.constexpr,
409
+ BLOCK_N: tl.constexpr,
410
+ HAS_DRESIDUAL: tl.constexpr,
411
+ STORE_DRESIDUAL: tl.constexpr,
412
+ HAS_BIAS: tl.constexpr,
413
+ HAS_DROPOUT: tl.constexpr,
414
+ HAS_ROWSCALE: tl.constexpr,
415
+ HAS_DY1: tl.constexpr,
416
+ HAS_DX1: tl.constexpr,
417
+ HAS_B1: tl.constexpr,
418
+ RECOMPUTE_OUTPUT: tl.constexpr,
419
+ ):
420
+ # Map the program id to the elements of X, DX, and DY it should compute.
421
+ row_block_id = tl.program_id(0)
422
+ row_start = row_block_id * rows_per_program
423
+ # Do not early exit if row_start >= M, because we need to write DW and DB
424
+ cols = tl.arange(0, BLOCK_N)
425
+ mask = cols < N
426
+ X += row_start * stride_x_row
427
+ if HAS_DRESIDUAL:
428
+ DRESIDUAL += row_start * stride_dres_row
429
+ if STORE_DRESIDUAL:
430
+ DRESIDUAL_IN += row_start * stride_dres_in_row
431
+ DY += row_start * stride_dy_row
432
+ DX += row_start * stride_dx_row
433
+ if HAS_DY1:
434
+ DY1 += row_start * stride_dy1_row
435
+ if HAS_DX1:
436
+ DX1 += row_start * stride_dx1_row
437
+ if RECOMPUTE_OUTPUT:
438
+ Y += row_start * stride_y_row
439
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
440
+ if zero_centered_weight:
441
+ w += 1.0
442
+ if RECOMPUTE_OUTPUT and HAS_BIAS:
443
+ b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
444
+ if HAS_DY1:
445
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
446
+ if zero_centered_weight:
447
+ w1 += 1.0
448
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
449
+ if HAS_BIAS:
450
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
451
+ if HAS_DY1:
452
+ dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
453
+ if HAS_B1:
454
+ db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
455
+ row_end = min((row_block_id + 1) * rows_per_program, M)
456
+ for row in range(row_start, row_end):
457
+ # Load data to SRAM
458
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
459
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
460
+ if HAS_DY1:
461
+ dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
462
+ if not IS_RMS_NORM:
463
+ mean = tl.load(Mean + row)
464
+ rstd = tl.load(Rstd + row)
465
+ # Compute dx
466
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
467
+ xhat = tl.where(mask, xhat, 0.0)
468
+ if RECOMPUTE_OUTPUT:
469
+ y = xhat * w + b if HAS_BIAS else xhat * w
470
+ tl.store(Y + cols, y, mask=mask)
471
+ wdy = w * dy
472
+ dw += dy * xhat
473
+ if HAS_BIAS:
474
+ db += dy
475
+ if HAS_DY1:
476
+ wdy += w1 * dy1
477
+ dw1 += dy1 * xhat
478
+ if HAS_B1:
479
+ db1 += dy1
480
+ if not IS_RMS_NORM:
481
+ c1 = tl.sum(xhat * wdy, axis=0) / N
482
+ c2 = tl.sum(wdy, axis=0) / N
483
+ dx = (wdy - (xhat * c1 + c2)) * rstd
484
+ else:
485
+ c1 = tl.sum(xhat * wdy, axis=0) / N
486
+ dx = (wdy - xhat * c1) * rstd
487
+ if HAS_DRESIDUAL:
488
+ dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
489
+ dx += dres
490
+ # Write dx
491
+ if STORE_DRESIDUAL:
492
+ tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
493
+ if HAS_DX1:
494
+ if HAS_DROPOUT:
495
+ keep_mask = (
496
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
497
+ )
498
+ dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
499
+ else:
500
+ dx1 = dx
501
+ tl.store(DX1 + cols, dx1, mask=mask)
502
+ if HAS_DROPOUT:
503
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
504
+ dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
505
+ if HAS_ROWSCALE:
506
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
507
+ dx *= rowscale
508
+ tl.store(DX + cols, dx, mask=mask)
509
+
510
+ X += stride_x_row
511
+ if HAS_DRESIDUAL:
512
+ DRESIDUAL += stride_dres_row
513
+ if STORE_DRESIDUAL:
514
+ DRESIDUAL_IN += stride_dres_in_row
515
+ if RECOMPUTE_OUTPUT:
516
+ Y += stride_y_row
517
+ DY += stride_dy_row
518
+ DX += stride_dx_row
519
+ if HAS_DY1:
520
+ DY1 += stride_dy1_row
521
+ if HAS_DX1:
522
+ DX1 += stride_dx1_row
523
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
524
+ if HAS_BIAS:
525
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
526
+ if HAS_DY1:
527
+ tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
528
+ if HAS_B1:
529
+ tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
530
+
531
+
532
+ def _layer_norm_bwd(
533
+ dy,
534
+ x,
535
+ weight,
536
+ bias,
537
+ eps,
538
+ mean,
539
+ rstd,
540
+ dresidual=None,
541
+ dy1=None,
542
+ weight1=None,
543
+ bias1=None,
544
+ seeds=None,
545
+ dropout_p=0.0,
546
+ rowscale=None,
547
+ has_residual=False,
548
+ has_x1=False,
549
+ zero_centered_weight=False,
550
+ is_rms_norm=False,
551
+ x_dtype=None,
552
+ recompute_output=False,
553
+ ):
554
+ M, N = x.shape
555
+ assert x.stride(-1) == 1
556
+ assert dy.stride(-1) == 1
557
+ assert dy.shape == (M, N)
558
+ if dresidual is not None:
559
+ assert dresidual.stride(-1) == 1
560
+ assert dresidual.shape == (M, N)
561
+ assert weight.shape == (N,)
562
+ assert weight.stride(-1) == 1
563
+ if bias is not None:
564
+ assert bias.stride(-1) == 1
565
+ assert bias.shape == (N,)
566
+ if dy1 is not None:
567
+ assert weight1 is not None
568
+ assert dy1.shape == dy.shape
569
+ assert dy1.stride(-1) == 1
570
+ if weight1 is not None:
571
+ assert weight1.shape == (N,)
572
+ assert weight1.stride(-1) == 1
573
+ if bias1 is not None:
574
+ assert bias1.shape == (N,)
575
+ assert bias1.stride(-1) == 1
576
+ if seeds is not None:
577
+ assert seeds.is_contiguous()
578
+ assert seeds.shape == (M if not has_x1 else M * 2,)
579
+ if rowscale is not None:
580
+ assert rowscale.is_contiguous()
581
+ assert rowscale.shape == (M,)
582
+ # allocate output
583
+ dx = (
584
+ torch.empty_like(x)
585
+ if x_dtype is None
586
+ else torch.empty(M, N, dtype=x_dtype, device=x.device)
587
+ )
588
+ dresidual_in = (
589
+ torch.empty_like(x)
590
+ if has_residual
591
+ and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
592
+ else None
593
+ )
594
+ dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
595
+ y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
596
+ if recompute_output:
597
+ assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
598
+
599
+ # Less than 64KB per feature: enqueue fused kernel
600
+ MAX_FUSED_SIZE = 65536 // x.element_size()
601
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
602
+ if N > BLOCK_N:
603
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
604
+ # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the
605
+ # latency of the gmem reads/writes, but will increase the time of summing up dw / db.
606
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8
607
+ _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
608
+ _db = (
609
+ torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
610
+ if bias is not None
611
+ else None
612
+ )
613
+ _dw1 = torch.empty_like(_dw) if weight1 is not None else None
614
+ _db1 = torch.empty_like(_db) if bias1 is not None else None
615
+ rows_per_program = math.ceil(M / sm_count)
616
+ grid = (sm_count,)
617
+ with torch.cuda.device(x.device.index):
618
+ _layer_norm_bwd_kernel[grid](
619
+ x,
620
+ weight,
621
+ bias,
622
+ y,
623
+ dy,
624
+ dx,
625
+ _dw,
626
+ _db,
627
+ dresidual,
628
+ weight1,
629
+ dy1,
630
+ dx1,
631
+ _dw1,
632
+ _db1,
633
+ dresidual_in,
634
+ rowscale,
635
+ seeds,
636
+ mean,
637
+ rstd,
638
+ x.stride(0),
639
+ 0 if not recompute_output else y.stride(0),
640
+ dy.stride(0),
641
+ dx.stride(0),
642
+ dresidual.stride(0) if dresidual is not None else 0,
643
+ dy1.stride(0) if dy1 is not None else 0,
644
+ dx1.stride(0) if dx1 is not None else 0,
645
+ dresidual_in.stride(0) if dresidual_in is not None else 0,
646
+ M,
647
+ N,
648
+ eps,
649
+ dropout_p,
650
+ zero_centered_weight,
651
+ rows_per_program,
652
+ is_rms_norm,
653
+ BLOCK_N,
654
+ dresidual is not None,
655
+ dresidual_in is not None,
656
+ bias is not None,
657
+ dropout_p > 0.0,
658
+ )
659
+ dw = _dw.sum(0).to(weight.dtype)
660
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
661
+ dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
662
+ db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
663
+ # Don't need to compute dresidual_in separately in this case
664
+ if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
665
+ dresidual_in = dx
666
+ if has_x1 and dropout_p == 0.0:
667
+ dx1 = dx
668
+ return (
669
+ (dx, dw, db, dresidual_in, dx1, dw1, db1)
670
+ if not recompute_output
671
+ else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
672
+ )
673
+
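As a rough sketch of the reduction strategy used by `_layer_norm_bwd` above (all sizes below are hypothetical): each launched program accumulates the weight/bias gradient of its own block of rows into one row of a float32 scratch buffer, and the final `dw`/`db` come from a single sum over that buffer.

import math
import torch

M, N, sm_count = 4096, 2048, 8 * 132        # rows, features, number of programs (illustrative)
xhat = torch.randn(M, N)                    # normalized activations
dy = torch.randn(M, N)                      # upstream gradient
_dw = torch.zeros(sm_count, N)              # one partial-sum row per program, like the buffer above
rows_per_program = math.ceil(M / sm_count)
for block in range(sm_count):               # emulates what the Triton programs do in parallel
    rows = slice(block * rows_per_program, min((block + 1) * rows_per_program, M))
    _dw[block] = (dy[rows] * xhat[rows]).sum(0)
dw = _dw.sum(0)                             # final reduction, mirroring `_dw.sum(0).to(weight.dtype)`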
674
+ class LayerNormFn(torch.autograd.Function):
675
+ @staticmethod
676
+ def forward(
677
+ ctx,
678
+ x,
679
+ weight,
680
+ bias,
681
+ residual=None,
682
+ x1=None,
683
+ weight1=None,
684
+ bias1=None,
685
+ eps=1e-6,
686
+ dropout_p=0.0,
687
+ rowscale=None,
688
+ prenorm=False,
689
+ residual_in_fp32=False,
690
+ zero_centered_weight=False,
691
+ is_rms_norm=False,
692
+ return_dropout_mask=False,
693
+ out=None,
694
+ residual_out=None
695
+ ):
696
+ x_shape_og = x.shape
697
+ # Check for zero sequence length
698
+ if x.numel() == 0:
699
+ ctx.zero_seq_length = True
700
+ # Only save minimal required tensors for backward
701
+ # ctx.save_for_backward(weight, bias, weight1, bias1)
702
+ ctx.x_shape_og = x_shape_og
703
+ ctx.weight_shape = weight.shape
704
+ ctx.weight_dtype = weight.dtype
705
+ ctx.weight_device = weight.device
706
+
707
+ ctx.has_bias = bias is not None
708
+ ctx.bias_shape = bias.shape if bias is not None else None
709
+ ctx.bias_dtype = bias.dtype if bias is not None else None
710
+ ctx.bias_device = bias.device if bias is not None else None
711
+
712
+ ctx.has_weight1 = weight1 is not None
713
+ ctx.weight1_shape = weight1.shape if weight1 is not None else None
714
+ ctx.weight1_dtype = weight1.dtype if weight1 is not None else None
715
+ ctx.weight1_device = weight1.device if weight1 is not None else None
716
+
717
+ ctx.has_bias1 = bias1 is not None
718
+ ctx.bias1_shape = bias1.shape if bias1 is not None else None
719
+ ctx.bias1_dtype = bias1.dtype if bias1 is not None else None
720
+ ctx.bias1_device = bias1.device if bias1 is not None else None
721
+
722
+ ctx.has_residual = residual is not None
723
+ ctx.has_x1 = x1 is not None
724
+ ctx.dropout_p = dropout_p
725
+
726
+ # Handle output tensors with correct dtype
727
+ y = x # Preserve input tensor properties
728
+ y1 = torch.empty_like(x) if x1 is not None else None
729
+
730
+ # Only create residual_out if prenorm is True
731
+ residual_out = torch.empty(x.shape,
732
+ dtype=torch.float32 if residual_in_fp32 else x.dtype,
733
+ device=x.device) if prenorm else None
734
+
735
+ # Handle dropout masks
736
+ dropout_mask = None
737
+ dropout_mask1 = None
738
+ if return_dropout_mask:
739
+ dropout_mask = torch.empty_like(x, dtype=torch.uint8)
740
+ if x1 is not None:
741
+ dropout_mask1 = torch.empty_like(x, dtype=torch.uint8)
742
+
743
+ # Return based on configuration
744
+ if not return_dropout_mask:
745
+ if weight1 is None:
746
+ return y if not prenorm else (y, residual_out)
747
+ else:
748
+ return (y, y1) if not prenorm else (y, y1, residual_out)
749
+ else:
750
+ if weight1 is None:
751
+ return ((y, dropout_mask, dropout_mask1) if not prenorm
752
+ else (y, residual_out, dropout_mask, dropout_mask1))
753
+ else:
754
+ return ((y, y1, dropout_mask, dropout_mask1) if not prenorm
755
+ else (y, y1, residual_out, dropout_mask, dropout_mask1))
756
+
757
+ ctx.zero_seq_length = False
758
+ # reshape input data into 2D tensor
759
+ x = x.reshape(-1, x.shape[-1])
760
+ if x.stride(-1) != 1:
761
+ x = x.contiguous()
762
+ if residual is not None:
763
+ assert residual.shape == x_shape_og
764
+ residual = residual.reshape(-1, residual.shape[-1])
765
+ if residual.stride(-1) != 1:
766
+ residual = residual.contiguous()
767
+ if x1 is not None:
768
+ assert x1.shape == x_shape_og
769
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
770
+ x1 = x1.reshape(-1, x1.shape[-1])
771
+ if x1.stride(-1) != 1:
772
+ x1 = x1.contiguous()
773
+ weight = weight.contiguous()
774
+ if bias is not None:
775
+ bias = bias.contiguous()
776
+ if weight1 is not None:
777
+ weight1 = weight1.contiguous()
778
+ if bias1 is not None:
779
+ bias1 = bias1.contiguous()
780
+ if rowscale is not None:
781
+ rowscale = rowscale.reshape(-1).contiguous()
782
+ residual_dtype = (
783
+ residual.dtype
784
+ if residual is not None
785
+ else (torch.float32 if residual_in_fp32 else None)
786
+ )
787
+ if out is not None:
788
+ out = out.reshape(-1, out.shape[-1])
789
+ if residual_out is not None:
790
+ residual_out = residual_out.reshape(-1, residual_out.shape[-1])
791
+ y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
792
+ x,
793
+ weight,
794
+ bias,
795
+ eps,
796
+ residual,
797
+ x1,
798
+ weight1,
799
+ bias1,
800
+ dropout_p=dropout_p,
801
+ rowscale=rowscale,
802
+ residual_dtype=residual_dtype,
803
+ zero_centered_weight=zero_centered_weight,
804
+ is_rms_norm=is_rms_norm,
805
+ return_dropout_mask=return_dropout_mask,
806
+ out=out,
807
+ residual_out=residual_out
808
+ )
809
+ ctx.save_for_backward(
810
+ residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
811
+ )
812
+ ctx.x_shape_og = x_shape_og
813
+ ctx.eps = eps
814
+ ctx.dropout_p = dropout_p
815
+ ctx.is_rms_norm = is_rms_norm
816
+ ctx.has_residual = residual is not None
817
+ ctx.has_x1 = x1 is not None
818
+ ctx.prenorm = prenorm
819
+ ctx.x_dtype = x.dtype
820
+ ctx.zero_centered_weight = zero_centered_weight
821
+ y = y.reshape(x_shape_og)
822
+ y1 = y1.reshape(x_shape_og) if y1 is not None else None
823
+ residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
824
+ dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
825
+ dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
826
+ if not return_dropout_mask:
827
+ if weight1 is None:
828
+ return y if not prenorm else (y, residual_out)
829
+ else:
830
+ return (y, y1) if not prenorm else (y, y1, residual_out)
831
+ else:
832
+ if weight1 is None:
833
+ return (
834
+ (y, dropout_mask, dropout_mask1)
835
+ if not prenorm
836
+ else (y, residual_out, dropout_mask, dropout_mask1)
837
+ )
838
+ else:
839
+ return (
840
+ (y, y1, dropout_mask, dropout_mask1)
841
+ if not prenorm
842
+ else (y, y1, residual_out, dropout_mask, dropout_mask1)
843
+ )
844
+
845
+ @staticmethod
846
+ def backward(ctx, dy, *args):
847
+ if ctx.zero_seq_length:
848
+ return (
849
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device),
850
+ torch.zeros(ctx.weight_shape, dtype=ctx.weight_dtype, device=ctx.weight_device),
851
+ torch.zeros(ctx.bias_shape, dtype=ctx.bias_dtype, device=ctx.bias_device) if ctx.has_bias else None,
852
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_residual else None,
853
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_x1 and ctx.dropout_p > 0.0 else None,
854
+ torch.zeros(ctx.weight1_shape, dtype=ctx.weight1_dtype, device=ctx.weight1_device) if ctx.has_weight1 else None,
855
+ torch.zeros(ctx.bias1_shape, dtype=ctx.bias1_dtype, device=ctx.bias1_device) if ctx.has_bias1 else None,
856
+ None,
857
+ None,
858
+ None,
859
+ None,
860
+ None,
861
+ None,
862
+ None,
863
+ None,
864
+ None,
865
+ None,
866
+ )
867
+
868
+ x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
869
+ dy = dy.reshape(-1, dy.shape[-1])
870
+ if dy.stride(-1) != 1:
871
+ dy = dy.contiguous()
872
+ assert dy.shape == x.shape
873
+ if weight1 is not None:
874
+ dy1, args = args[0], args[1:]
875
+ dy1 = dy1.reshape(-1, dy1.shape[-1])
876
+ if dy1.stride(-1) != 1:
877
+ dy1 = dy1.contiguous()
878
+ assert dy1.shape == x.shape
879
+ else:
880
+ dy1 = None
881
+ if ctx.prenorm:
882
+ dresidual = args[0]
883
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
884
+ if dresidual.stride(-1) != 1:
885
+ dresidual = dresidual.contiguous()
886
+ assert dresidual.shape == x.shape
887
+ else:
888
+ dresidual = None
889
+
890
+ dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
891
+ dy,
892
+ x,
893
+ weight,
894
+ bias,
895
+ ctx.eps,
896
+ mean,
897
+ rstd,
898
+ dresidual,
899
+ dy1,
900
+ weight1,
901
+ bias1,
902
+ seeds,
903
+ ctx.dropout_p,
904
+ rowscale,
905
+ ctx.has_residual,
906
+ ctx.has_x1,
907
+ ctx.zero_centered_weight,
908
+ ctx.is_rms_norm,
909
+ x_dtype=ctx.x_dtype,
910
+ )
911
+ return (
912
+ dx.reshape(ctx.x_shape_og),
913
+ dw,
914
+ db,
915
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
916
+ dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
917
+ dw1,
918
+ db1,
919
+ None,
920
+ None,
921
+ None,
922
+ None,
923
+ None,
924
+ None,
925
+ None,
926
+ None,
927
+ None,
928
+ None,
929
+ )
930
+
931
+ def rms_norm_fn(
932
+ x,
933
+ weight,
934
+ bias,
935
+ residual=None,
936
+ x1=None,
937
+ weight1=None,
938
+ bias1=None,
939
+ eps=1e-6,
940
+ dropout_p=0.0,
941
+ rowscale=None,
942
+ prenorm=False,
943
+ residual_in_fp32=False,
944
+ zero_centered_weight=False,
945
+ return_dropout_mask=False,
946
+ out=None,
947
+ residual_out=None
948
+ ):
949
+ return LayerNormFn.apply(
950
+ x,
951
+ weight,
952
+ bias,
953
+ residual,
954
+ x1,
955
+ weight1,
956
+ bias1,
957
+ eps,
958
+ dropout_p,
959
+ rowscale,
960
+ prenorm,
961
+ residual_in_fp32,
962
+ zero_centered_weight,
963
+ True,
964
+ return_dropout_mask,
965
+ out,
966
+ residual_out
967
+ )
968
+
969
+ class RMSNorm(torch.nn.Module):
970
+ def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False,
971
+ device=None, dtype=None):
972
+ factory_kwargs = {"device": device, "dtype": dtype}
973
+ super().__init__()
974
+ self.eps = eps
975
+ if dropout_p > 0.0:
976
+ self.drop = torch.nn.Dropout(dropout_p)
977
+ else:
978
+ self.drop = None
979
+ self.zero_centered_weight = zero_centered_weight
980
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
981
+ self.register_parameter("bias", None)
982
+ self.reset_parameters()
983
+
984
+ def reset_parameters(self):
985
+ if not self.zero_centered_weight:
986
+ torch.nn.init.ones_(self.weight)
987
+ else:
988
+ torch.nn.init.zeros_(self.weight)
989
+
990
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
991
+ return rms_norm_fn(
992
+ x,
993
+ self.weight,
994
+ self.bias,
995
+ residual=residual,
996
+ eps=self.eps,
997
+ dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
998
+ prenorm=prenorm,
999
+ residual_in_fp32=residual_in_fp32,
1000
+ zero_centered_weight=self.zero_centered_weight,
1001
+ )
1002
+ else:
1003
+ from torch.nn import RMSNorm
1004
+ warnings.warn("Triton is not available; install triton to use the fused RMSNorm for better performance")
1005
+
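Whichever branch runs, an `RMSNorm` module with the same basic call pattern ends up in scope; a minimal usage sketch with made-up sizes (the fallback path additionally needs a PyTorch version that ships `torch.nn.RMSNorm`):

norm = RMSNorm(2304, eps=1e-5)      # width matches the transformer hidden size
x = torch.randn(2, 128, 2304)       # (batch, sequence, hidden)
y = norm(x)                         # same shape as x; uses the fused Triton kernel when available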
1006
+ def swiglu(x, y):
1007
+ return F.silu(x.float(), inplace=False).to(x.dtype) * y
1008
+
1009
+ logger = logging.get_logger(__name__)
1010
+
1011
+
1012
+ class TimestepEmbedding(nn.Module):
1013
+ def __init__(
1014
+ self,
1015
+ in_channels: int,
1016
+ time_embed_dim: int,
1017
+ act_fn: str = "silu",
1018
+ out_dim: int = None,
1019
+ post_act_fn: Optional[str] = None,
1020
+ cond_proj_dim=None,
1021
+ sample_proj_bias=True,
1022
+ ):
1023
+ super().__init__()
1024
+
1025
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
1026
+
1027
+ if cond_proj_dim is not None:
1028
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
1029
+ else:
1030
+ self.cond_proj = None
1031
+
1032
+ self.act = get_activation(act_fn)
1033
+
1034
+ if out_dim is not None:
1035
+ time_embed_dim_out = out_dim
1036
+ else:
1037
+ time_embed_dim_out = time_embed_dim
1038
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
1039
+
1040
+ if post_act_fn is None:
1041
+ self.post_act = None
1042
+ else:
1043
+ self.post_act = get_activation(post_act_fn)
1044
+
1045
+ self.initialize_weights()
1046
+
1047
+ def initialize_weights(self):
1048
+ nn.init.normal_(self.linear_1.weight, std=0.02)
1049
+ nn.init.zeros_(self.linear_1.bias)
1050
+ nn.init.normal_(self.linear_2.weight, std=0.02)
1051
+ nn.init.zeros_(self.linear_2.bias)
1052
+
1053
+ def forward(self, sample, condition=None):
1054
+ if condition is not None:
1055
+ sample = sample + self.cond_proj(condition)
1056
+ sample = self.linear_1(sample)
1057
+
1058
+ if self.act is not None:
1059
+ sample = self.act(sample)
1060
+
1061
+ sample = self.linear_2(sample)
1062
+
1063
+ if self.post_act is not None:
1064
+ sample = self.post_act(sample)
1065
+ return sample
1066
+
1067
+ def apply_rotary_emb(
1068
+ x: torch.Tensor,
1069
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
1070
+ use_real: bool = True,
1071
+ use_real_unbind_dim: int = -1,
1072
+ ) -> torch.Tensor:
1073
+ """
1074
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
1075
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
1076
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
1077
+ tensors contain rotary embeddings and are returned as real tensors.
1078
+
1079
+ Args:
1080
+ x (`torch.Tensor`):
1081
+ Query or key tensor to apply rotary embeddings to, of shape [B, H, S, D].
1082
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
1083
+
1084
+ Returns:
1085
+ torch.Tensor: The input tensor with rotary embeddings applied.
1086
+ """
1087
+ if use_real:
1088
+ cos, sin = freqs_cis # [S, D]
1089
+ cos = cos[None, None]
1090
+ sin = sin[None, None]
1091
+ cos, sin = cos.to(x.device), sin.to(x.device)
1092
+
1093
+ if use_real_unbind_dim == -1:
1094
+ # Used for flux, cogvideox, hunyuan-dit
1095
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
1096
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
1097
+ elif use_real_unbind_dim == -2:
1098
+ # Used for Stable Audio, OmniGen and CogView4
1099
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
1100
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
1101
+ else:
1102
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
1103
+
1104
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
1105
+
1106
+ return out
1107
+ else:
1108
+ # used for lumina
1109
+ # x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
1110
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
1111
+ freqs_cis = freqs_cis.unsqueeze(2)
1112
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
1113
+
1114
+ return x_out.type_as(x)
1115
+
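For orientation, a toy example of the complex (`use_real=False`) path, which is the one the attention processor below uses; all sizes are made up:

B, S, H, D = 2, 16, 4, 64
x = torch.randn(B, S, H, D)                                  # query or key, heads on dim 2
angles = torch.randn(B, S, D // 2)
freqs_cis = torch.polar(torch.ones(B, S, D // 2), angles)    # one complex rotation per (position, dim pair)
out = apply_rotary_emb(x, freqs_cis, use_real=False)
assert out.shape == x.shape and out.dtype == x.dtype         # the rotation preserves shape and dtype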
1116
+ class OmniGen2RotaryPosEmbed(nn.Module):
1117
+ def __init__(self, theta: int,
1118
+ axes_dim: Tuple[int, int, int],
1119
+ axes_lens: Tuple[int, int, int] = (300, 512, 512),
1120
+ patch_size: int = 2):
1121
+ super().__init__()
1122
+ self.theta = theta
1123
+ self.axes_dim = axes_dim
1124
+ self.axes_lens = axes_lens
1125
+ self.patch_size = patch_size
1126
+
1127
+ @staticmethod
1128
+ def get_freqs_cis(axes_dim: Tuple[int, int, int],
1129
+ axes_lens: Tuple[int, int, int],
1130
+ theta: int) -> List[torch.Tensor]:
1131
+ freqs_cis = []
1132
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
1133
+ for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
1134
+ emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
1135
+ freqs_cis.append(emb)
1136
+ return freqs_cis
1137
+
1138
+ def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
1139
+ device = ids.device
1140
+ if ids.device.type == "mps":
1141
+ ids = ids.to("cpu")
1142
+
1143
+ result = []
1144
+ for i in range(len(self.axes_dim)):
1145
+ freqs = freqs_cis[i].to(ids.device)
1146
+ index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
1147
+ result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
1148
+ return torch.cat(result, dim=-1).to(device)
1149
+
1150
+ def forward(
1151
+ self,
1152
+ freqs_cis,
1153
+ attention_mask,
1154
+ l_effective_ref_img_len,
1155
+ l_effective_img_len,
1156
+ ref_img_sizes,
1157
+ img_sizes,
1158
+ device
1159
+ ):
1160
+ batch_size = len(attention_mask)
1161
+ p = self.patch_size
1162
+
1163
+ encoder_seq_len = attention_mask.shape[1]
1164
+ l_effective_cap_len = attention_mask.sum(dim=1).tolist()
1165
+
1166
+ seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]
1167
+
1168
+ max_seq_len = max(seq_lengths)
1169
+ max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
1170
+ max_img_len = max(l_effective_img_len)
1171
+
1172
+ # Create position IDs
1173
+ position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
1174
+
1175
+ for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
1176
+ # add text position ids
1177
+ position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")
1178
+
1179
+ pe_shift = cap_seq_len
1180
+ pe_shift_len = cap_seq_len
1181
+
1182
+ if ref_img_sizes[i] is not None:
1183
+ for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
1184
+ H, W = ref_img_size
1185
+ ref_H_tokens, ref_W_tokens = H // p, W // p
1186
+ assert ref_H_tokens * ref_W_tokens == ref_img_len
1187
+ # add image position ids
1188
+
1189
+ row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
1190
+ col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
1191
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
1192
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
1193
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
1194
+
1195
+ pe_shift += max(ref_H_tokens, ref_W_tokens)
1196
+ pe_shift_len += ref_img_len
1197
+
1198
+ H, W = img_sizes[i]
1199
+ H_tokens, W_tokens = H // p, W // p
1200
+ assert H_tokens * W_tokens == l_effective_img_len[i]
1201
+
1202
+ row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
1203
+ col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()
1204
+
1205
+ assert pe_shift_len + l_effective_img_len[i] == seq_len
1206
+ position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
1207
+ position_ids[i, pe_shift_len: seq_len, 1] = row_ids
1208
+ position_ids[i, pe_shift_len: seq_len, 2] = col_ids
1209
+
1210
+ # Get combined rotary embeddings
1211
+ freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)
1212
+
1213
+ # create separate rotary embeddings for captions and images
1214
+ cap_freqs_cis = torch.zeros(
1215
+ batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
1216
+ )
1217
+ ref_img_freqs_cis = torch.zeros(
1218
+ batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
1219
+ )
1220
+ img_freqs_cis = torch.zeros(
1221
+ batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
1222
+ )
1223
+
1224
+ for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
1225
+ cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
1226
+ ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
1227
+ img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]
1228
+
1229
+ return (
1230
+ cap_freqs_cis,
1231
+ ref_img_freqs_cis,
1232
+ img_freqs_cis,
1233
+ freqs_cis,
1234
+ l_effective_cap_len,
1235
+ seq_lengths,
1236
+ )
1237
+
1238
+
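The position ids built in `forward` span three axes: a sequence index for text tokens and (row, column) indices for image patches, with the first axis of every image token pinned one step past the caption. A small illustration with a 5-token caption and an 8x8 patch grid (hypothetical sizes, no reference images):

cap_len, H_tokens, W_tokens = 5, 8, 8
pos_text = [(t, t, t) for t in range(cap_len)]                                  # (0,0,0) ... (4,4,4)
pos_img = [(cap_len, r, c) for r in range(H_tokens) for c in range(W_tokens)]   # (5,0,0) ... (5,7,7)
# each axis indexes its own frequency table from get_freqs_cis; the three embeddings are
# concatenated along the last dimension (e.g. axes_dim 32 + 32 + 32 = 96 = head_dim)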
1239
+ class LuminaRMSNormZero(nn.Module):
1240
+ """
1241
+ Adaptive RMS normalization with zero-initialized modulation (AdaLN-Zero style).
1242
+
1243
+ Parameters:
1244
+ embedding_dim (`int`): The size of each embedding vector.
1245
+ """
1246
+
1247
+ def __init__(
1248
+ self,
1249
+ embedding_dim: int,
1250
+ norm_eps: float,
1251
+ norm_elementwise_affine: bool,
1252
+ ):
1253
+ super().__init__()
1254
+ self.silu = nn.SiLU()
1255
+ self.linear = nn.Linear(
1256
+ min(embedding_dim, 1024),
1257
+ 4 * embedding_dim,
1258
+ bias=True,
1259
+ )
1260
+ self.norm = RMSNorm(embedding_dim, eps=norm_eps)
1261
+
1262
+ def forward(
1263
+ self,
1264
+ x: torch.Tensor,
1265
+ emb: Optional[torch.Tensor] = None,
1266
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1267
+ emb = self.linear(self.silu(emb))
1268
+ scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
1269
+ x = self.norm(x) * (1 + scale_msa[:, None])
1270
+ return x, gate_msa, scale_mlp, gate_mlp
1271
+
1272
+
1273
+ class LuminaLayerNormContinuous(nn.Module):
1274
+ def __init__(
1275
+ self,
1276
+ embedding_dim: int,
1277
+ conditioning_embedding_dim: int,
1278
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
1279
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
1280
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
1281
+ # However, this is how it was implemented in the original code, and it's rather likely you should
1282
+ # set `elementwise_affine` to False.
1283
+ elementwise_affine=True,
1284
+ eps=1e-5,
1285
+ bias=True,
1286
+ norm_type="layer_norm",
1287
+ out_dim: Optional[int] = None,
1288
+ ):
1289
+ super().__init__()
1290
+
1291
+ # AdaLN
1292
+ self.silu = nn.SiLU()
1293
+ self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
1294
+
1295
+ if norm_type == "layer_norm":
1296
+ self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
1297
+ elif norm_type == "rms_norm":
1298
+ self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
1299
+ else:
1300
+ raise ValueError(f"unknown norm_type {norm_type}")
1301
+
1302
+ self.linear_2 = None
1303
+ if out_dim is not None:
1304
+ self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
1305
+
1306
+ def forward(
1307
+ self,
1308
+ x: torch.Tensor,
1309
+ conditioning_embedding: torch.Tensor,
1310
+ ) -> torch.Tensor:
1311
+ # convert back to the original dtype in case `conditioning_embedding` was upcast to float32 (needed for HunyuanDiT)
1312
+ emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
1313
+ scale = emb
1314
+ x = self.norm(x) * (1 + scale)[:, None, :]
1315
+
1316
+ if self.linear_2 is not None:
1317
+ x = self.linear_2(x)
1318
+
1319
+ return x
1320
+
1321
+
1322
+ class LuminaFeedForward(nn.Module):
1323
+ r"""
1324
+ A feed-forward layer.
1325
+
1326
+ Parameters:
1327
+ dim (`int`):
1328
+ The dimensionality of the layer's input and output. This parameter determines the width of the model's
1329
+ hidden representations.
1330
+ inner_dim (`int`): The intermediate dimension of the feed-forward layer before rounding.
1331
+ multiple_of (`int`, *optional*): The intermediate dimension is rounded up to a multiple
1332
+ of this value.
1333
+ ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
1334
+ dimension. Defaults to None.
1335
+ """
1336
+
1337
+ def __init__(
1338
+ self,
1339
+ dim: int,
1340
+ inner_dim: int,
1341
+ multiple_of: Optional[int] = 256,
1342
+ ffn_dim_multiplier: Optional[float] = None,
1343
+ ):
1344
+ super().__init__()
1345
+
1346
+ self.swiglu = swiglu
1347
+
1348
+ # custom hidden_size factor multiplier
1349
+ if ffn_dim_multiplier is not None:
1350
+ inner_dim = int(ffn_dim_multiplier * inner_dim)
1351
+ inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
1352
+
1353
+ self.linear_1 = nn.Linear(
1354
+ dim,
1355
+ inner_dim,
1356
+ bias=False,
1357
+ )
1358
+ self.linear_2 = nn.Linear(
1359
+ inner_dim,
1360
+ dim,
1361
+ bias=False,
1362
+ )
1363
+ self.linear_3 = nn.Linear(
1364
+ dim,
1365
+ inner_dim,
1366
+ bias=False,
1367
+ )
1368
+
1369
+ def forward(self, x):
1370
+ h1, h2 = self.linear_1(x), self.linear_3(x)
1371
+ return self.linear_2(self.swiglu(h1, h2))
1372
+
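A worked example of the dimension rounding in `__init__` above (the multiplier value is hypothetical):

dim = 2304
inner_dim = 4 * dim                        # 9216, as passed in by the transformer block
ffn_dim_multiplier = 0.7                   # illustrative custom multiplier
multiple_of = 256
inner_dim = int(ffn_dim_multiplier * inner_dim)                            # 6451
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)   # 6656, the next multiple of 256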
1373
+
1374
+ class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
1375
+ def __init__(
1376
+ self,
1377
+ hidden_size: int = 4096,
1378
+ text_feat_dim: int = 2048,
1379
+ frequency_embedding_size: int = 256,
1380
+ norm_eps: float = 1e-5,
1381
+ timestep_scale: float = 1.0,
1382
+ ) -> None:
1383
+ super().__init__()
1384
+
1385
+ self.time_proj = Timesteps(
1386
+ num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale
1387
+ )
1388
+
1389
+ self.timestep_embedder = TimestepEmbedding(
1390
+ in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
1391
+ )
1392
+
1393
+ self.caption_embedder = nn.Sequential(
1394
+ RMSNorm(text_feat_dim, eps=norm_eps),
1395
+ nn.Linear(text_feat_dim, hidden_size, bias=True),
1396
+ )
1397
+
1398
+ self._initialize_weights()
1399
+
1400
+ def _initialize_weights(self):
1401
+ nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02)
1402
+ nn.init.zeros_(self.caption_embedder[1].bias)
1403
+
1404
+ def forward(
1405
+ self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
1406
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1407
+ timestep_proj = self.time_proj(timestep).to(dtype=dtype)
1408
+ time_embed = self.timestep_embedder(timestep_proj)
1409
+ caption_embed = self.caption_embedder(text_hidden_states)
1410
+ return time_embed, caption_embed
1411
+
1412
+
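A quick shape sketch of what this embedding module produces, with made-up batch and sequence sizes: the timestep embedding lives in min(hidden_size, 1024) dimensions (it feeds the AdaLN modulation), while the caption features are projected all the way to hidden_size.

embed = Lumina2CombinedTimestepCaptionEmbedding(hidden_size=2304, text_feat_dim=1024)
timestep = torch.rand(2)                          # one diffusion timestep per sample
text_hidden_states = torch.randn(2, 77, 1024)     # features from the text encoder
temb, cap = embed(timestep, text_hidden_states, dtype=torch.float32)
# temb: (2, 1024), cap: (2, 77, 2304)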
1413
+ class OmniGen2AttnProcessor:
1414
+ """
1415
+ Processor for implementing scaled dot-product attention.
1416
+
1417
+ This processor is optimized for PyTorch 2.0 and implements:
1418
+ - Scaled dot-product attention (PyTorch SDPA)
1419
+ - Rotary position embeddings (RoPE)
1420
+ - Query-Key normalization
1421
+ - Proportional attention scaling
1422
+
1423
+ Args:
1424
+ None
1425
+
1426
+ Raises:
1427
+ ImportError: If PyTorch version is less than 2.0
1428
+ """
1429
+
1430
+ def __init__(self) -> None:
1431
+ """Initialize the attention processor."""
1432
+ if not hasattr(F, "scaled_dot_product_attention"):
1433
+ raise ImportError(
1434
+ "OmniGen2AttnProcessorFlash2Varlen requires PyTorch 2.0. "
1435
+ "Please upgrade PyTorch to version 2.0 or later."
1436
+ )
1437
+
1438
+ def __call__(
1439
+ self,
1440
+ attn: Attention,
1441
+ hidden_states: torch.Tensor,
1442
+ encoder_hidden_states: torch.Tensor,
1443
+ attention_mask: Optional[torch.Tensor] = None,
1444
+ image_rotary_emb: Optional[torch.Tensor] = None,
1445
+ base_sequence_length: Optional[int] = None,
1446
+ ) -> torch.Tensor:
1447
+ """
1448
+ Compute attention with PyTorch's scaled dot-product attention.
1449
+
1450
+ Args:
1451
+ attn: Attention module
1452
+ hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
1453
+ encoder_hidden_states: Encoder hidden states tensor
1454
+ attention_mask: Optional attention mask tensor
1455
+ image_rotary_emb: Optional rotary embeddings for image tokens
1456
+ base_sequence_length: Optional base sequence length for proportional attention
1457
+
1458
+ Returns:
1459
+ torch.Tensor: Processed hidden states after attention computation
1460
+ """
1461
+ batch_size, sequence_length, _ = hidden_states.shape
1462
+
1463
+ # Get Query-Key-Value Pair
1464
+ query = attn.to_q(hidden_states)
1465
+ key = attn.to_k(encoder_hidden_states)
1466
+ value = attn.to_v(encoder_hidden_states)
1467
+
1468
+ query_dim = query.shape[-1]
1469
+ inner_dim = key.shape[-1]
1470
+ head_dim = query_dim // attn.heads
1471
+ dtype = query.dtype
1472
+
1473
+ # Get key-value heads
1474
+ kv_heads = inner_dim // head_dim
1475
+
1476
+ # Reshape tensors for attention computation
1477
+ query = query.view(batch_size, -1, attn.heads, head_dim)
1478
+ key = key.view(batch_size, -1, kv_heads, head_dim)
1479
+ value = value.view(batch_size, -1, kv_heads, head_dim)
1480
+
1481
+ # Apply Query-Key normalization
1482
+ if attn.norm_q is not None:
1483
+ query = attn.norm_q(query)
1484
+ if attn.norm_k is not None:
1485
+ key = attn.norm_k(key)
1486
+
1487
+ # Apply Rotary Position Embeddings
1488
+ if image_rotary_emb is not None:
1489
+ query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
1490
+ key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
1491
+
1492
+ query, key = query.to(dtype), key.to(dtype)
1493
+
1494
+ # Calculate attention scale
1495
+ if base_sequence_length is not None:
1496
+ softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
1497
+ else:
1498
+ softmax_scale = attn.scale
1499
+
1500
+ # scaled_dot_product_attention expects attention_mask shape to be
1501
+ # (batch, heads, source_length, target_length)
1502
+ if attention_mask is not None:
1503
+ attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
1504
+
1505
+ query = query.transpose(1, 2)
1506
+ key = key.transpose(1, 2)
1507
+ value = value.transpose(1, 2)
1508
+
1509
+ # Explicitly repeat key and value to match the number of query heads; otherwise enable_gqa=True falls back to the math backend of SDPA in our tests with PyTorch 2.6
1510
+ key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
1511
+ value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
1512
+
1513
+ hidden_states = F.scaled_dot_product_attention(
1514
+ query, key, value, attn_mask=attention_mask, scale=softmax_scale
1515
+ )
1516
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1517
+ hidden_states = hidden_states.type_as(query)
1518
+
1519
+ # Apply output projection
1520
+ hidden_states = attn.to_out[0](hidden_states)
1521
+ hidden_states = attn.to_out[1](hidden_states)
1522
+
1523
+ return hidden_states
1524
+
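Two details of the processor are easy to miss, illustrated here with made-up numbers: the proportional scale grows with the logarithm of the current sequence length relative to `base_sequence_length`, and grouped-query attention is realized by repeating the key/value heads before SDPA.

import math
head_dim, heads, kv_heads = 96, 24, 8
base_scale = 1 / math.sqrt(head_dim)                         # attn.scale in the default diffusers Attention
sequence_length, base_sequence_length = 4096, 1024
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * base_scale
# log_1024(4096) = 1.2, so the scale is sqrt(1.2) ~= 1.095 times the usual 1/sqrt(head_dim)
# GQA: the 8 key/value heads are repeated 24 // 8 = 3 times to line up with the 24 query heads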
1525
+ class OmniGen2TransformerBlock(nn.Module):
1526
+ """
1527
+ Transformer block for OmniGen2 model.
1528
+
1529
+ This block implements a transformer layer with:
1530
+ - Multi-head attention with flash attention
1531
+ - Feed-forward network with SwiGLU activation
1532
+ - RMS normalization
1533
+ - Optional modulation for conditional generation
1534
+
1535
+ Args:
1536
+ dim: Dimension of the input and output tensors
1537
+ num_attention_heads: Number of attention heads
1538
+ num_kv_heads: Number of key-value heads
1539
+ multiple_of: Value the feed-forward hidden dimension is rounded up to a multiple of
1540
+ ffn_dim_multiplier: Multiplier for the feed-forward network dimension
1541
+ norm_eps: Epsilon value for normalization layers
1542
+ modulation: Whether to use modulation for conditional generation
1543
+ use_fused_rms_norm: Whether to use fused RMS normalization
1544
+ use_fused_swiglu: Whether to use fused SwiGLU activation
1545
+ """
1546
+
1547
+ def __init__(
1548
+ self,
1549
+ dim: int,
1550
+ num_attention_heads: int,
1551
+ num_kv_heads: int,
1552
+ multiple_of: int,
1553
+ ffn_dim_multiplier: float,
1554
+ norm_eps: float,
1555
+ modulation: bool = True,
1556
+ ) -> None:
1557
+ """Initialize the transformer block."""
1558
+ super().__init__()
1559
+ self.head_dim = dim // num_attention_heads
1560
+ self.modulation = modulation
1561
+
1562
+ # Initialize attention layer
1563
+ self.attn = Attention(
1564
+ query_dim=dim,
1565
+ cross_attention_dim=None,
1566
+ dim_head=dim // num_attention_heads,
1567
+ qk_norm="rms_norm",
1568
+ heads=num_attention_heads,
1569
+ kv_heads=num_kv_heads,
1570
+ eps=1e-5,
1571
+ bias=False,
1572
+ out_bias=False,
1573
+ processor=OmniGen2AttnProcessor(),
1574
+ )
1575
+
1576
+ # Initialize feed-forward network
1577
+ self.feed_forward = LuminaFeedForward(
1578
+ dim=dim,
1579
+ inner_dim=4 * dim,
1580
+ multiple_of=multiple_of,
1581
+ ffn_dim_multiplier=ffn_dim_multiplier,
1582
+ )
1583
+
1584
+ # Initialize normalization layers
1585
+ if modulation:
1586
+ self.norm1 = LuminaRMSNormZero(
1587
+ embedding_dim=dim,
1588
+ norm_eps=norm_eps,
1589
+ norm_elementwise_affine=True,
1590
+ )
1591
+ else:
1592
+ self.norm1 = RMSNorm(dim, eps=norm_eps)
1593
+
1594
+ self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
1595
+ self.norm2 = RMSNorm(dim, eps=norm_eps)
1596
+ self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
1597
+
1598
+ self.initialize_weights()
1599
+
1600
+ def initialize_weights(self) -> None:
1601
+ """
1602
+ Initialize the weights of the transformer block.
1603
+
1604
+ Uses Xavier uniform initialization for linear layers and zero initialization for biases.
1605
+ """
1606
+ nn.init.xavier_uniform_(self.attn.to_q.weight)
1607
+ nn.init.xavier_uniform_(self.attn.to_k.weight)
1608
+ nn.init.xavier_uniform_(self.attn.to_v.weight)
1609
+ nn.init.xavier_uniform_(self.attn.to_out[0].weight)
1610
+
1611
+ nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
1612
+ nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
1613
+ nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)
1614
+
1615
+ if self.modulation:
1616
+ nn.init.zeros_(self.norm1.linear.weight)
1617
+ nn.init.zeros_(self.norm1.linear.bias)
1618
+
1619
+ def forward(
1620
+ self,
1621
+ hidden_states: torch.Tensor,
1622
+ attention_mask: torch.Tensor,
1623
+ image_rotary_emb: torch.Tensor,
1624
+ temb: Optional[torch.Tensor] = None,
1625
+ ) -> torch.Tensor:
1626
+ """
1627
+ Forward pass of the transformer block.
1628
+
1629
+ Args:
1630
+ hidden_states: Input hidden states tensor
1631
+ attention_mask: Attention mask tensor
1632
+ image_rotary_emb: Rotary embeddings for image tokens
1633
+ temb: Optional timestep embedding tensor
1634
+
1635
+ Returns:
1636
+ torch.Tensor: Output hidden states after transformer block processing
1637
+ """
1638
+ if self.modulation:
1639
+ if temb is None:
1640
+ raise ValueError("temb must be provided when modulation is enabled")
1641
+
1642
+ norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
1643
+ attn_output = self.attn(
1644
+ hidden_states=norm_hidden_states,
1645
+ encoder_hidden_states=norm_hidden_states,
1646
+ attention_mask=attention_mask,
1647
+ image_rotary_emb=image_rotary_emb,
1648
+ )
1649
+ hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
1650
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
1651
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
1652
+ else:
1653
+ norm_hidden_states = self.norm1(hidden_states)
1654
+ attn_output = self.attn(
1655
+ hidden_states=norm_hidden_states,
1656
+ encoder_hidden_states=norm_hidden_states,
1657
+ attention_mask=attention_mask,
1658
+ image_rotary_emb=image_rotary_emb,
1659
+ )
1660
+ hidden_states = hidden_states + self.norm2(attn_output)
1661
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
1662
+ hidden_states = hidden_states + self.ffn_norm2(mlp_output)
1663
+
1664
+ return hidden_states
1665
+
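In compact form, the modulated branch above performs the following updates (a paraphrase of the code, not extra functionality):

# norm_h, gate_msa, scale_mlp, gate_mlp = norm1(h, temb)                     # AdaLN-Zero modulation
# h = h + tanh(gate_msa) * norm2(Attn(norm_h))                               # gated attention residual
# h = h + tanh(gate_mlp) * ffn_norm2(FFN(ffn_norm1(h) * (1 + scale_mlp)))    # gated feed-forward residual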
1666
+
1667
+ class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
1668
+ """
1669
+ OmniGen2 Transformer 2D Model.
1670
+
1671
+ A transformer-based diffusion model for image generation with:
1672
+ - Patch-based image processing
1673
+ - Rotary position embeddings
1674
+ - Multi-head attention
1675
+ - Conditional generation support
1676
+
1677
+ Args:
1678
+ patch_size: Size of image patches
1679
+ in_channels: Number of input channels
1680
+ out_channels: Number of output channels (defaults to in_channels)
1681
+ hidden_size: Size of hidden layers
1682
+ num_layers: Number of transformer layers
1683
+ num_refiner_layers: Number of refiner layers
1684
+ num_attention_heads: Number of attention heads
1685
+ num_kv_heads: Number of key-value heads
1686
+ multiple_of: Value the feed-forward hidden dimension is rounded up to a multiple of
1687
+ ffn_dim_multiplier: Multiplier for feed-forward network dimension
1688
+ norm_eps: Epsilon value for normalization layers
1689
+ axes_dim_rope: Dimensions for rotary position embeddings
1690
+ axes_lens: Lengths for rotary position embeddings
1691
+ text_feat_dim: Dimension of text features
1692
+ timestep_scale: Scale factor for timestep embeddings
1693
+ use_fused_rms_norm: Whether to use fused RMS normalization
1694
+ use_fused_swiglu: Whether to use fused SwiGLU activation
1695
+ """
1696
+
1697
+ _supports_gradient_checkpointing = True
1698
+ _no_split_modules = ["Omnigen2TransformerBlock"]
1699
+ _skip_layerwise_casting_patterns = ["x_embedder", "norm"]
1700
+
1701
+ @register_to_config
1702
+ def __init__(
1703
+ self,
1704
+ patch_size: int = 2,
1705
+ in_channels: int = 16,
1706
+ out_channels: Optional[int] = None,
1707
+ hidden_size: int = 2304,
1708
+ num_layers: int = 26,
1709
+ num_refiner_layers: int = 2,
1710
+ num_attention_heads: int = 24,
1711
+ num_kv_heads: int = 8,
1712
+ multiple_of: int = 256,
1713
+ ffn_dim_multiplier: Optional[float] = None,
1714
+ norm_eps: float = 1e-5,
1715
+ axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
1716
+ axes_lens: Tuple[int, int, int] = (300, 512, 512),
1717
+ text_feat_dim: int = 1024,
1718
+ timestep_scale: float = 1.0,
1719
+ ) -> None:
1720
+ """Initialize the OmniGen2 transformer model."""
1721
+ super().__init__()
1722
+
1723
+ # Validate configuration
1724
+ if (hidden_size // num_attention_heads) != sum(axes_dim_rope):
1725
+ raise ValueError(
1726
+ f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) "
1727
+ f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})"
1728
+ )
1729
+
1730
+ self.out_channels = out_channels or in_channels
1731
+
1732
+ # Initialize embeddings
1733
+ self.rope_embedder = OmniGen2RotaryPosEmbed(
1734
+ theta=10000,
1735
+ axes_dim=axes_dim_rope,
1736
+ axes_lens=axes_lens,
1737
+ patch_size=patch_size,
1738
+ )
1739
+
1740
+ self.x_embedder = nn.Linear(
1741
+ in_features=patch_size * patch_size * in_channels,
1742
+ out_features=hidden_size,
1743
+ )
1744
+
1745
+ self.ref_image_patch_embedder = nn.Linear(
1746
+ in_features=patch_size * patch_size * in_channels,
1747
+ out_features=hidden_size,
1748
+ )
1749
+
1750
+ self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
1751
+ hidden_size=hidden_size,
1752
+ text_feat_dim=text_feat_dim,
1753
+ norm_eps=norm_eps,
1754
+ timestep_scale=timestep_scale,
1755
+ )
1756
+
1757
+ # Initialize transformer blocks
1758
+ self.noise_refiner = nn.ModuleList([
1759
+ OmniGen2TransformerBlock(
1760
+ hidden_size,
1761
+ num_attention_heads,
1762
+ num_kv_heads,
1763
+ multiple_of,
1764
+ ffn_dim_multiplier,
1765
+ norm_eps,
1766
+ modulation=True,
1767
+ )
1768
+ for _ in range(num_refiner_layers)
1769
+ ])
1770
+
1771
+ self.ref_image_refiner = nn.ModuleList([
1772
+ OmniGen2TransformerBlock(
1773
+ hidden_size,
1774
+ num_attention_heads,
1775
+ num_kv_heads,
1776
+ multiple_of,
1777
+ ffn_dim_multiplier,
1778
+ norm_eps,
1779
+ modulation=True,
1780
+ )
1781
+ for _ in range(num_refiner_layers)
1782
+ ])
1783
+
1784
+ self.context_refiner = nn.ModuleList(
1785
+ [
1786
+ OmniGen2TransformerBlock(
1787
+ hidden_size,
1788
+ num_attention_heads,
1789
+ num_kv_heads,
1790
+ multiple_of,
1791
+ ffn_dim_multiplier,
1792
+ norm_eps,
1793
+ modulation=False,
1794
+ )
1795
+ for _ in range(num_refiner_layers)
1796
+ ]
1797
+ )
1798
+
1799
+ # 3. Transformer blocks
1800
+ self.layers = nn.ModuleList(
1801
+ [
1802
+ OmniGen2TransformerBlock(
1803
+ hidden_size,
1804
+ num_attention_heads,
1805
+ num_kv_heads,
1806
+ multiple_of,
1807
+ ffn_dim_multiplier,
1808
+ norm_eps,
1809
+ modulation=True,
1810
+ )
1811
+ for _ in range(num_layers)
1812
+ ]
1813
+ )
1814
+
1815
+ # 4. Output norm & projection
1816
+ self.norm_out = LuminaLayerNormContinuous(
1817
+ embedding_dim=hidden_size,
1818
+ conditioning_embedding_dim=min(hidden_size, 1024),
1819
+ elementwise_affine=False,
1820
+ eps=1e-6,
1821
+ bias=True,
1822
+ out_dim=patch_size * patch_size * self.out_channels,
1823
+ )
1824
+
1825
+ # Add learnable embeddings to distinguish different images
1826
+ self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size)) # support max 5 ref images
1827
+
1828
+ self.gradient_checkpointing = False
1829
+
1830
+ self.initialize_weights()
1831
+
1832
+ def initialize_weights(self) -> None:
1833
+ """
1834
+ Initialize the weights of the model.
1835
+
1836
+ Uses Xavier uniform initialization for linear layers.
1837
+ """
1838
+ nn.init.xavier_uniform_(self.x_embedder.weight)
1839
+ nn.init.constant_(self.x_embedder.bias, 0.0)
1840
+
1841
+ nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
1842
+ nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)
1843
+
1844
+ nn.init.zeros_(self.norm_out.linear_1.weight)
1845
+ nn.init.zeros_(self.norm_out.linear_1.bias)
1846
+ nn.init.zeros_(self.norm_out.linear_2.weight)
1847
+ nn.init.zeros_(self.norm_out.linear_2.bias)
1848
+
1849
+ nn.init.normal_(self.image_index_embedding, std=0.02)
1850
+
1851
+ def img_patch_embed_and_refine(
1852
+ self,
1853
+ hidden_states,
1854
+ ref_image_hidden_states,
1855
+ padded_img_mask,
1856
+ padded_ref_img_mask,
1857
+ noise_rotary_emb,
1858
+ ref_img_rotary_emb,
1859
+ l_effective_ref_img_len,
1860
+ l_effective_img_len,
1861
+ temb
1862
+ ):
1863
+ batch_size = len(hidden_states)
1864
+ max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)])
1865
+
1866
+ hidden_states = self.x_embedder(hidden_states)
1867
+ ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
1868
+
1869
+ for i in range(batch_size):
1870
+ shift = 0
1871
+ for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
1872
+ ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j]
1873
+ shift += ref_img_len
1874
+
1875
+ for layer in self.noise_refiner:
1876
+ hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
1877
+
1878
+ flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
1879
+ num_ref_images = len(flat_l_effective_ref_img_len)
1880
+ max_ref_img_len = max(flat_l_effective_ref_img_len)
1881
+
1882
+ batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool)
1883
+ batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size)
1884
+ batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype)
1885
+ batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)
1886
+
1887
+ # sequence of ref imgs to batch
1888
+ idx = 0
1889
+ for i in range(batch_size):
1890
+ shift = 0
1891
+ for ref_img_len in l_effective_ref_img_len[i]:
1892
+ batch_ref_img_mask[idx, :ref_img_len] = True
1893
+ batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len]
1894
+ batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len]
1895
+ batch_temb[idx] = temb[i]
1896
+ shift += ref_img_len
1897
+ idx += 1
1898
+
1899
+ # refine ref imgs separately
1900
+ for layer in self.ref_image_refiner:
1901
+ batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb)
1902
+
1903
+ # batch of ref imgs to sequence
1904
+ idx = 0
1905
+ for i in range(batch_size):
1906
+ shift = 0
1907
+ for ref_img_len in l_effective_ref_img_len[i]:
1908
+ ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len]
1909
+ shift += ref_img_len
1910
+ idx += 1
1911
+
1912
+ combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size)
1913
+ for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
1914
+ combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)]
1915
+ combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len]
1916
+
1917
+ return combined_img_hidden_states
1918
+
1919
+ def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
1920
+ batch_size = len(hidden_states)
1921
+ p = self.config.patch_size
1922
+ device = hidden_states[0].device
1923
+
1924
+ img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
1925
+ l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
1926
+
1927
+ if ref_image_hidden_states is not None:
1928
+ ref_img_sizes = [[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states]
1929
+ l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
1930
+ else:
1931
+ ref_img_sizes = [None for _ in range(batch_size)]
1932
+ l_effective_ref_img_len = [[0] for _ in range(batch_size)]
1933
+
1934
+ max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
1935
+ max_img_len = max(l_effective_img_len)
1936
+
1937
+ # ref image patch embeddings
1938
+ flat_ref_img_hidden_states = []
1939
+ for i in range(batch_size):
1940
+ if ref_img_sizes[i] is not None:
1941
+ imgs = []
1942
+ for ref_img in ref_image_hidden_states[i]:
1943
+ C, H, W = ref_img.size()
1944
+ ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
1945
+ imgs.append(ref_img)
1946
+
1947
+ img = torch.cat(imgs, dim=0)
1948
+ flat_ref_img_hidden_states.append(img)
1949
+ else:
1950
+ flat_ref_img_hidden_states.append(None)
1951
+
1952
+ # image patch embeddings
1953
+ flat_hidden_states = []
1954
+ for i in range(batch_size):
1955
+ img = hidden_states[i]
1956
+ C, H, W = img.size()
1957
+
1958
+ img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
1959
+ flat_hidden_states.append(img)
1960
+
1961
+ padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
1962
+ padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device)
1963
+ for i in range(batch_size):
1964
+ if ref_img_sizes[i] is not None:
1965
+ padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i]
1966
+ padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True
1967
+
1968
+ padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
1969
+ padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
1970
+ for i in range(batch_size):
1971
+ padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i]
1972
+ padded_img_mask[i, :l_effective_img_len[i]] = True
1973
+
1974
+ return (
1975
+ padded_hidden_states,
1976
+ padded_ref_img_hidden_states,
1977
+ padded_img_mask,
1978
+ padded_ref_img_mask,
1979
+ l_effective_ref_img_len,
1980
+ l_effective_img_len,
1981
+ ref_img_sizes,
1982
+ img_sizes,
1983
+ )
1984
+
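The `rearrange` pattern used above packs each patch_size x patch_size block of latent pixels into a single token; with illustrative sizes:

from einops import rearrange
img = torch.randn(16, 64, 64)      # (latent channels, H, W)
p = 2
tokens = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
# tokens.shape == (32 * 32, 2 * 2 * 16) == (1024, 64); 1024 is this image's l_effective_img_len entry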
1985
+ def forward(
1986
+ self,
1987
+ hidden_states: Union[torch.Tensor, List[torch.Tensor]],
1988
+ timestep: torch.Tensor,
1989
+ text_hidden_states: torch.Tensor,
1990
+ freqs_cis: torch.Tensor,
1991
+ text_attention_mask: torch.Tensor,
1992
+ ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None,
1993
+ attention_kwargs: Optional[Dict[str, Any]] = None,
1994
+ return_dict: bool = False,
1995
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
1996
+ if attention_kwargs is not None:
1997
+ attention_kwargs = attention_kwargs.copy()
1998
+ lora_scale = attention_kwargs.pop("scale", 1.0)
1999
+ else:
2000
+ lora_scale = 1.0
2001
+
2002
+ if USE_PEFT_BACKEND:
2003
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
2004
+ scale_lora_layers(self, lora_scale)
2005
+ else:
2006
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
2007
+ logger.warning(
2008
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
2009
+ )
2010
+
2011
+ # 1. Condition, positional & patch embedding
2012
+ batch_size = len(hidden_states)
2013
+ is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)
2014
+
2015
+ if is_hidden_states_tensor:
2016
+ assert hidden_states.ndim == 4
2017
+ hidden_states = [_hidden_states for _hidden_states in hidden_states]
2018
+
2019
+ device = hidden_states[0].device
2020
+
2021
+ temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
2022
+
2023
+ (
2024
+ hidden_states,
2025
+ ref_image_hidden_states,
2026
+ img_mask,
2027
+ ref_img_mask,
2028
+ l_effective_ref_img_len,
2029
+ l_effective_img_len,
2030
+ ref_img_sizes,
2031
+ img_sizes,
2032
+ ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
2033
+
2034
+ (
2035
+ context_rotary_emb,
2036
+ ref_img_rotary_emb,
2037
+ noise_rotary_emb,
2038
+ rotary_emb,
2039
+ encoder_seq_lengths,
2040
+ seq_lengths,
2041
+ ) = self.rope_embedder(
2042
+ freqs_cis,
2043
+ text_attention_mask,
2044
+ l_effective_ref_img_len,
2045
+ l_effective_img_len,
2046
+ ref_img_sizes,
2047
+ img_sizes,
2048
+ device,
2049
+ )
2050
+
2051
+ # 2. Context refinement
2052
+ for layer in self.context_refiner:
2053
+ text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
2054
+
2055
+ combined_img_hidden_states = self.img_patch_embed_and_refine(
2056
+ hidden_states,
2057
+ ref_image_hidden_states,
2058
+ img_mask,
2059
+ ref_img_mask,
2060
+ noise_rotary_emb,
2061
+ ref_img_rotary_emb,
2062
+ l_effective_ref_img_len,
2063
+ l_effective_img_len,
2064
+ temb,
2065
+ )
2066
+
2067
+ # 3. Joint Transformer blocks
2068
+ max_seq_len = max(seq_lengths)
2069
+
2070
+ attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
2071
+ joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
2072
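+ # per sample, the joint sequence is laid out as [text tokens 0:encoder_seq_len][image tokens encoder_seq_len:seq_len]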
+ for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
2073
+ attention_mask[i, :seq_len] = True
2074
+ joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len]
2075
+ joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len]
2076
+
2077
+ hidden_states = joint_hidden_states
2078
+
2079
+ for layer_idx, layer in enumerate(self.layers):
2080
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
2081
+ hidden_states = self._gradient_checkpointing_func(
2082
+ layer, hidden_states, attention_mask, rotary_emb, temb
2083
+ )
2084
+ else:
2085
+ hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
2086
+
2087
+ # 4. Output norm & projection
2088
+ hidden_states = self.norm_out(hidden_states, temb)
2089
+
2090
+ p = self.config.patch_size
2091
+ output = []
2092
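+ # the last l_effective_img_len[i] tokens of each valid sequence are the target-image tokens; un-patchify them back to (c, H, W)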
+ for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
2093
+ height, width = img_size
2094
+ output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p))
2095
+ if is_hidden_states_tensor:
2096
+ output = torch.stack(output, dim=0)
2097
+
2098
+ if USE_PEFT_BACKEND:
2099
+ # remove `lora_scale` from each PEFT layer
2100
+ unscale_lora_layers(self, lora_scale)
2101
+
2102
+ if not return_dict:
2103
+ return output
2104
+ return Transformer2DModelOutput(sample=output)
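
A minimal standalone sketch of the patchify / un-patchify rearrange used by the transformer above, under assumed toy shapes (the patch size and latent dimensions below are illustrative, not taken from the model config):

import torch
from einops import rearrange

p = 2                              # assumed patch size
latent = torch.randn(16, 64, 96)   # (c, H, W), with H and W divisible by p

# patchify: one token per p x p patch, feature dim = p*p*c (as in flat_and_pad_to_seq)
tokens = rearrange(latent, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
assert tokens.shape == (32 * 48, p * p * 16)

# un-patchify: the inverse pattern applied in the output projection
restored = rearrange(tokens, '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=64 // p, w=96 // p, p1=p, p2=p)
assert torch.equal(latent, restored)   # the two rearranges round-trip exactly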
vae/config.json ADDED
@@ -0,0 +1,38 @@
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.33.1",
4
+ "_name_or_path": "/share_2/luoxin/modelscope/hub/models/FLUX.1-dev",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "force_upcast": true,
19
+ "in_channels": 3,
20
+ "latent_channels": 16,
21
+ "latents_mean": null,
22
+ "latents_std": null,
23
+ "layers_per_block": 2,
24
+ "mid_block_add_attention": true,
25
+ "norm_num_groups": 32,
26
+ "out_channels": 3,
27
+ "sample_size": 1024,
28
+ "scaling_factor": 0.3611,
29
+ "shift_factor": 0.1159,
30
+ "up_block_types": [
31
+ "UpDecoderBlock2D",
32
+ "UpDecoderBlock2D",
33
+ "UpDecoderBlock2D",
34
+ "UpDecoderBlock2D"
35
+ ],
36
+ "use_post_quant_conv": false,
37
+ "use_quant_conv": false
38
+ }
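
The config above describes a FLUX-derived 16-latent-channel AutoencoderKL with scaling_factor 0.3611 and shift_factor 0.1159. A minimal sketch, assuming the repo id "BAAI/OmniGen2" and a dummy input image, of loading this VAE from the vae/ subfolder with diffusers and applying the shift-then-scale latent convention commonly used with this VAE family (the actual pipeline code may differ):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("BAAI/OmniGen2", subfolder="vae", torch_dtype=torch.float32)

image = torch.rand(1, 3, 1024, 1024) * 2 - 1   # dummy image in [-1, 1]
with torch.no_grad():
    # encode, then shift and scale the sampled latents
    latents = vae.encode(image).latent_dist.sample()
    latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor

    # undo the convention before decoding
    decoded = vae.decode(latents / vae.config.scaling_factor + vae.config.shift_factor).sample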
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c717328c8ad41faab2ccfd52ae17332505c6833cf176aad56e7b58f2c4d4c94
3
+ size 335306212
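
The safetensors entries in this commit are git-lfs pointers (a sha256 oid plus a byte size) rather than the tensors themselves. A minimal sketch for verifying a downloaded shard against its pointer, assuming the file has been fetched to the same relative path locally:

import hashlib
from pathlib import Path

path = Path("vae/diffusion_pytorch_model.safetensors")   # assumed local location
expected_sha256 = "8c717328c8ad41faab2ccfd52ae17332505c6833cf176aad56e7b58f2c4d4c94"
expected_size = 335306212

digest = hashlib.sha256()
with path.open("rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        digest.update(chunk)

assert path.stat().st_size == expected_size, "size does not match the LFS pointer"
assert digest.hexdigest() == expected_sha256, "sha256 does not match the LFS pointer"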