Safetensors

roman-bachmann committed
Commit 47410c9 · 0 Parent(s): Initial commit

Files changed (4):
  1. .gitattributes +35 -0
  2. README.md +79 -0
  3. config.json +297 -0
  4. model.safetensors +3 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,79 @@
+ ---
+ license: apple-amlr
+ ---
+
+ # FlexTok: Resampling Images into 1D Token Sequences of Flexible Length
+
+ [`Website`](https://flextok.epfl.ch) | [`arXiv`](https://arxiv.org/abs/2502.13967) | [`GitHub`](https://github.com/apple/ml-flextok) | [`🤗 Demo`](https://huggingface.co/spaces/EPFL-VILAB/FlexTok) | [`BibTeX`](#citation)
+
+ Official implementation and pre-trained models for: <br>
+ [**FlexTok: Resampling Images into 1D Token Sequences of Flexible Length**](https://arxiv.org/abs/2502.13967), arXiv 2025 <br>
+ *[Roman Bachmann](https://roman-bachmann.github.io/)\*, [Jesse Allardice](https://github.com/JesseAllardice)\*, [David Mizrahi](https://dmizrahi.com/)\*, [Enrico Fini](https://scholar.google.com/citations?user=OQMtSKIAAAAJ), [Oğuzhan Fatih Kar](https://ofkar.github.io/), [Elmira Amirloo](https://elamirloo.github.io/), [Alaaeldin El-Nouby](https://aelnouby.github.io/), [Amir Zamir](https://vilab.epfl.ch/zamir/), [Afshin Dehghan](https://scholar.google.com/citations?user=wcX-UW4AAAAJ)*
+
+
+ ## Installation
+ For install instructions, please see https://github.com/apple/ml-flextok.
+
+
+ ## Usage
+
+ To load the `FlexTok d18-d28 ImageNet-1k` model directly from HuggingFace Hub, call:
+ ```python
+ from flextok.flextok_wrapper import FlexTokFromHub
+ model = FlexTokFromHub.from_pretrained('EPFL-VILAB/flextok_d18_d28_in1k').eval()
+ ```
+
+ The model can also be loaded by downloading the `model.safetensors` checkpoint in this repository manually and loading it using our helper functions:
+ ```python
+ from hydra.utils import instantiate
+ from flextok.utils.checkpoint import load_safetensors
+
+ ckpt, config = load_safetensors('/path/to/model.safetensors')
+ model = instantiate(config).eval()
+ model.load_state_dict(ckpt)
+ ```
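+
+ If you prefer to fetch the checkpoint programmatically rather than downloading it by hand, the same helper-function path can be combined with `huggingface_hub`. A minimal sketch, assuming `huggingface_hub` is installed (the repository id and filename come from this repo; everything else mirrors the snippet above):
+ ```python
+ from huggingface_hub import hf_hub_download
+ from hydra.utils import instantiate
+ from flextok.utils.checkpoint import load_safetensors
+
+ # Download model.safetensors from this repository into the local HF cache
+ ckpt_path = hf_hub_download(repo_id='EPFL-VILAB/flextok_d18_d28_in1k', filename='model.safetensors')
+
+ ckpt, config = load_safetensors(ckpt_path)
+ model = instantiate(config).eval()
+ model.load_state_dict(ckpt)
+ ```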
35
+
36
+ After loading a FlexTok model, image batches can be encoded using:
37
+ ```python
38
+ from flextok.utils.demo import imgs_from_urls
39
+ # Load example images of shape (B, 3, 256, 256), normalized to [-1,1]
40
+ imgs = imgs_from_urls(urls=['https://storage.googleapis.com/flextok_site/nb_demo_images/0.png'])
41
+
42
+ # tokens_list is a list of [1, 256] discrete token sequences
43
+ tokens_list = model.tokenize(imgs)
44
+ ```
45
+
46
+ The list of token sequences can be truncated in a nested fashion:
47
+ ```python
48
+ k_keep = 64 # For example, only keep the first 64 out of 256 tokens
49
+ tokens_list = [t[:,:k_keep] for t in tokens_list]
50
+ ```
51
+
52
+ To decode the tokens with FlexTok's rectified flow decoder, call:
53
+ ```python
54
+ # tokens_list is a list of [1, l] discrete token sequences, with l <= 256
55
+ # reconst is a [B, 3, 256, 256] tensor, normalized to [-1,1]
56
+ reconst = model.detokenize(
57
+ tokens_list,
58
+ timesteps=20, # Number of denoising steps
59
+ guidance_scale=7.5, # Classifier-free guidance scale
60
+ perform_norm_guidance=True, # See https://arxiv.org/abs/2410.02416
61
+ )
62
+ ```
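+
+ Putting the pieces above together, here is a minimal end-to-end sketch that tokenizes an image, decodes it from a few nested prefix lengths, and writes the reconstructions to disk. It only combines the calls documented above; the de-normalization and the `torchvision.utils.save_image` call are illustrative assumptions rather than part of the FlexTok API:
+ ```python
+ import torch
+ from torchvision.utils import save_image
+
+ from flextok.flextok_wrapper import FlexTokFromHub
+ from flextok.utils.demo import imgs_from_urls
+
+ model = FlexTokFromHub.from_pretrained('EPFL-VILAB/flextok_d18_d28_in1k').eval()
+
+ # Example image of shape (1, 3, 256, 256), normalized to [-1,1]
+ imgs = imgs_from_urls(urls=['https://storage.googleapis.com/flextok_site/nb_demo_images/0.png'])
+
+ with torch.no_grad():
+     tokens_list = model.tokenize(imgs)  # list of [1, 256] token sequences
+
+     for k_keep in (16, 64, 256):  # decode from progressively longer token prefixes
+         truncated = [t[:, :k_keep] for t in tokens_list]
+         reconst = model.detokenize(
+             truncated,
+             timesteps=20,
+             guidance_scale=7.5,
+             perform_norm_guidance=True,
+         )
+         # Map reconstructions from [-1,1] back to [0,1] before saving
+         save_image(reconst.clamp(-1, 1) * 0.5 + 0.5, f'reconst_k{k_keep}.png')
+ ```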
+
+
+ ## Citation
+
+ If you find this repository helpful, please consider citing our work:
+ ```
+ @article{flextok,
+     title={{FlexTok}: Resampling Images into 1D Token Sequences of Flexible Length},
+     author={Roman Bachmann and Jesse Allardice and David Mizrahi and Enrico Fini and O{\u{g}}uzhan Fatih Kar and Elmira Amirloo and Alaaeldin El-Nouby and Amir Zamir and Afshin Dehghan},
+     journal={arXiv 2025},
+     year={2025},
+ }
+ ```
+
+ ## License
+
+ The model weights in this repository are released under the Apple Model License for Research.
config.json ADDED
@@ -0,0 +1,297 @@
+ {
+   "vae": {
+     "_target_": "flextok.vae_wrapper.StableDiffusionVAE",
+     "images_read_key": "rgb",
+     "vae_latents_read_key": "vae_latents_reconst",
+     "vae_latents_write_key": "vae_latents",
+     "images_reconst_write_key": "rgb_reconst",
+     "vae_kl_loss_write_key": "kl_loss",
+     "dtype_override": null,
+     "sample_posterior": true,
+     "compile_encode_fn": false,
+     "force_vae_encode": true,
+     "latent_channels": 16,
+     "scaling_factor": 0.88
+   },
+   "regularizer": {
+     "_target_": "flextok.regularizers.quantize_fsq.FSQ",
+     "latents_read_key": "enc_registers",
+     "quants_write_key": "enc_registers_quant",
+     "tokens_write_key": "tokens",
+     "levels": [
+       8,
+       8,
+       8,
+       5,
+       5,
+       5
+     ],
+     "drop_quant_p": 0.0,
+     "packed_call": false
+   },
+   "encoder": {
+     "_target_": "flextok.model.utils.wrappers.SequentialModuleDictWrapper",
+     "module_dict": {
+       "enc_channels_to_last": {
+         "_target_": "flextok.model.utils.dict_ops.PerSampleOp",
+         "read_key": "vae_latents",
+         "write_key": "vae_latents_bhwc",
+         "per_sample_op": {
+           "_target_": "flextok.model.utils.dict_ops.channels_first_to_last",
+           "_partial_": true
+         }
+       },
+       "enc_patch_emb": {
+         "_target_": "flextok.model.preprocessors.patching.PatchEmbedder",
+         "input_tensor_list_read_key": "vae_latents_bhwc",
+         "patches_list_write_key": "enc_vae_latents_patched",
+         "n_patches_write_key": "enc_n_patches",
+         "channels_in": 16,
+         "dim": 1152,
+         "patch_sizes": [
+           2,
+           2
+         ],
+         "flatten_patches": false
+       },
+       "enc_posemb_module": {
+         "_target_": "flextok.model.utils.posembs.PositionalEmbeddingAdder",
+         "read_key": "enc_vae_latents_patched",
+         "write_key": "enc_vae_latents_patched",
+         "dim": 1152,
+         "max_sizes": [
+           16,
+           16
+         ],
+         "posemb_type": "sincos",
+         "posemb_scaling": "absolute"
+       },
+       "enc_register_module": {
+         "_target_": "flextok.model.preprocessors.registers.Registers1D",
+         "input_tensor_list_read_key": "enc_vae_latents_patched",
+         "register_sizes_read_write_key": "register_sizes",
+         "registers_write_key": "enc_registers",
+         "dim": 1152,
+         "n_min": 256,
+         "n_max": 256,
+         "size_sampling_mode": "uniform",
+         "ordering_mode": "nested"
+       },
+       "enc_seq_packer": {
+         "_target_": "flextok.model.preprocessors.flex_seq_packing.BlockWiseSequencePacker",
+         "input_list_read_keys": [
+           "enc_vae_latents_patched",
+           "enc_registers"
+         ],
+         "packed_seq_write_key": "enc_packed_seq",
+         "block_mask_write_key": "enc_block_mask",
+         "inner_packed_shapes_write_key": "enc_ps_inner",
+         "outer_packed_shapes_write_key": "enc_ps_outer",
+         "mask_mode": "causal_last",
+         "pad_to_multiple": 128
+       },
+       "enc_transformer": {
+         "_target_": "flextok.model.trunks.transformers.FlexTransformer",
+         "input_seq_read_key": "enc_packed_seq",
+         "output_seq_write_key": "enc_packed_seq",
+         "dim": 1152,
+         "depth": 18,
+         "block_mask_read_key": "enc_block_mask",
+         "use_act_checkpoint": true
+       },
+       "enc_unpacker": {
+         "_target_": "flextok.model.postprocessors.seq_unpacking.SequenceUnpacker",
+         "packed_seq_read_key": "enc_packed_seq",
+         "inner_seq_write_keys": [
+           "enc_vae_latents_patched",
+           "enc_registers"
+         ],
+         "inner_packed_shapes_read_key": "enc_ps_inner",
+         "outer_packed_shapes_read_key": "enc_ps_outer"
+       },
+       "enc_to_latents": {
+         "_target_": "flextok.model.postprocessors.heads.LinearHead",
+         "read_key": "enc_registers",
+         "write_key": "enc_registers",
+         "dim": 1152,
+         "dim_out": 6,
+         "use_mup_readout": false,
+         "weight_init_style": "zero",
+         "dtype_override": null
+       }
+     }
+   },
+   "flow_matching_noise_module": {
+     "_target_": "flextok.flow_matching.noise_modules.MinRFNoiseModule",
+     "clean_images_read_key": "vae_latents",
+     "noised_images_write_key": "vae_latents_noised",
+     "noise_write_key": "flow_noise",
+     "timesteps_write_key": "timesteps",
+     "sigmas_write_key": "sigmas",
+     "ln": false,
+     "stratisfied": false,
+     "mode_scale": 0.25
+   },
+   "decoder": {
+     "_target_": "flextok.model.utils.wrappers.SequentialModuleDictWrapper",
+     "module_dict": {
+       "dec_from_latents": {
+         "_target_": "flextok.model.preprocessors.linear.LinearLayer",
+         "read_key": "enc_registers_quant",
+         "write_key": "dec_registers_proj",
+         "dim_in": 6,
+         "dim": 1792
+       },
+       "dec_registers_posemb_module": {
+         "_target_": "flextok.model.utils.posembs.PositionalEmbeddingAdder",
+         "read_key": "dec_registers_proj",
+         "write_key": "dec_registers_proj",
+         "dim": 1792,
+         "max_sizes": [
+           256
+         ],
+         "posemb_type": "learnable_sum",
+         "posemb_scaling": "absolute"
+       },
+       "dec_nested_dropout": {
+         "_target_": "flextok.model.preprocessors.token_dropout.MaskedNestedDropout",
+         "read_write_key": "dec_registers_proj",
+         "dim": 1792,
+         "size_sampling_mode": "pow2"
+       },
+       "dec_latent_dropout": {
+         "_target_": "flextok.model.preprocessors.nullcond.LearnedNullCond",
+         "read_write_key": "dec_registers_proj",
+         "dim": 1792,
+         "dropout_prob": 0.2
+       },
+       "dec_noise_channels_to_last": {
+         "_target_": "flextok.model.utils.dict_ops.PerSampleOp",
+         "read_key": "vae_latents_noised",
+         "write_key": "vae_latents_noised_bhwc",
+         "per_sample_op": {
+           "_target_": "flextok.model.utils.dict_ops.channels_first_to_last",
+           "_partial_": true
+         }
+       },
+       "dec_noise_patch_emb": {
+         "_target_": "flextok.model.preprocessors.patching.PatchEmbedder",
+         "input_tensor_list_read_key": "vae_latents_noised_bhwc",
+         "patches_list_write_key": "vae_latents_noised_patched",
+         "n_patches_write_key": "dec_n_patches",
+         "channels_in": 16,
+         "dim": 1792,
+         "patch_sizes": [
+           2,
+           2
+         ],
+         "flatten_patches": false
+       },
+       "dec_patches_posemb_module": {
+         "_target_": "flextok.model.utils.posembs.PositionalEmbeddingAdder",
+         "read_key": "vae_latents_noised_patched",
+         "write_key": "dec_patches",
+         "dim": 1792,
+         "max_sizes": [
+           16,
+           16
+         ],
+         "posemb_type": "sincos",
+         "posemb_scaling": "absolute"
+       },
+       "dec_seq_packer": {
+         "_target_": "flextok.model.preprocessors.flex_seq_packing.BlockWiseSequencePacker",
+         "input_list_read_keys": [
+           "dec_patches",
+           "dec_registers_proj"
+         ],
+         "packed_seq_write_key": "dec_packed_seq",
+         "block_mask_write_key": "dec_block_mask",
+         "inner_packed_shapes_write_key": "dec_ps_inner",
+         "outer_packed_shapes_write_key": "dec_ps_outer",
+         "emb_packing_fn_write_key": "emb_packing_fn",
+         "mask_mode": "full",
+         "pad_to_multiple": 128,
+         "per_subseq_embs": true
+       },
+       "dec_time_embedder": {
+         "_target_": "flextok.model.preprocessors.time_embedding.TimestepEmbedder",
+         "timesteps_read_key": "timesteps",
+         "time_embedding_write_key": "dec_temb",
+         "dim": 1792,
+         "frequency_embedding_size": 256,
+         "max_timestep": 1000.0
+       },
+       "dec_transformer": {
+         "_target_": "flextok.model.trunks.transformers.FlexTransformer",
+         "input_seq_read_key": "dec_packed_seq",
+         "output_seq_write_key": "dec_packed_seq",
+         "dim": 1792,
+         "depth": 28,
+         "block_mask_read_key": "dec_block_mask",
+         "adaLN_emb_read_key": "dec_temb",
+         "adaLN_packing_fn_read_key": "emb_packing_fn",
+         "adaLN_expansion": 2,
+         "intermediate_layer_write_key": "dec_packed_seq_repa_layer",
+         "intermediate_layers": [
+           1
+         ],
+         "use_act_checkpoint": true
+       },
+       "dec_unpacker": {
+         "_target_": "flextok.model.postprocessors.seq_unpacking.SequenceUnpacker",
+         "packed_seq_read_key": "dec_packed_seq",
+         "inner_seq_write_keys": [
+           "dec_patches",
+           "dec_registers_proj"
+         ],
+         "inner_packed_shapes_read_key": "dec_ps_inner",
+         "outer_packed_shapes_read_key": "dec_ps_outer"
+       },
+       "dec_repa_unpacker": {
+         "_target_": "flextok.model.postprocessors.seq_unpacking.SequenceUnpacker",
+         "packed_seq_read_key": "dec_packed_seq_repa_layer",
+         "inner_seq_write_keys": [
+           "dec_patches_repa_layer",
+           "dec_registers_repa_layer"
+         ],
+         "inner_packed_shapes_read_key": "dec_ps_inner",
+         "outer_packed_shapes_read_key": "dec_ps_outer"
+       },
+       "dec_to_patches": {
+         "_target_": "flextok.model.postprocessors.heads.ToPatchesLinearHead",
+         "read_key": "dec_patches",
+         "write_key": "dec_patches",
+         "dim": 1792,
+         "channels_out": 16,
+         "patch_sizes": [
+           2,
+           2
+         ],
+         "use_mup_readout": false,
+         "weight_init_style": "zero",
+         "adaLN_emb_read_key": "dec_temb"
+       },
+       "dec_channels_to_first": {
+         "_target_": "flextok.model.utils.dict_ops.PerSampleOp",
+         "read_key": "dec_patches",
+         "write_key": "vae_latents_reconst",
+         "per_sample_op": {
+           "_target_": "flextok.model.utils.dict_ops.channels_last_to_first",
+           "_partial_": true
+         }
+       }
+     }
+   },
+   "_target_": "flextok.flextok_wrapper.FlexTok",
+   "pipeline": {
+     "_target_": "flextok.flow_matching.pipelines.MinRFPipeline",
+     "_partial_": true,
+     "target_sizes_read_key": null,
+     "latents_read_key": "enc_registers_quant",
+     "timesteps_read_key": "timesteps",
+     "noised_images_read_key": "vae_latents_noised",
+     "reconst_write_key": "vae_latents_reconst",
+     "out_channels": 16
+   }
+ }
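
The `regularizer` entry above configures FSQ with `levels` `[8, 8, 8, 5, 5, 5]`, i.e. each register is quantized into a 6-dimensional code. As a quick illustrative check (not part of the repository code), the implied codebook size is the product of the levels:

```python
import math

levels = [8, 8, 8, 5, 5, 5]  # FSQ levels from the regularizer config above
vocab_size = math.prod(levels)
print(vocab_size)  # 64000 distinct codes
```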
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:25ee13b1c7bffbbff832051bc8269b5a8293f0cebd6d6f9bd0a4e42f49531143
+ size 10163625244
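
Because `model.safetensors` is tracked with Git LFS, only this small pointer file lives in the Git history; the roughly 10 GB payload (10163625244 bytes) is fetched separately. A hedged sketch for verifying a downloaded checkpoint against the SHA-256 recorded in the pointer (the local path and chunk size are illustrative):

```python
import hashlib

# SHA-256 from the LFS pointer above
EXPECTED_SHA256 = '25ee13b1c7bffbbff832051bc8269b5a8293f0cebd6d6f9bd0a4e42f49531143'

def sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Hash the file in 1 MiB chunks so the ~10 GB checkpoint never needs to fit in memory."""
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()

assert sha256_of_file('/path/to/model.safetensors') == EXPECTED_SHA256
```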