Add model card for SphereAR

#1
by nielsr HF Staff - opened
Files changed (1)
  1. README.md +63 -0
README.md ADDED
@@ -0,0 +1,63 @@
---
pipeline_tag: text-to-image
---

# SphereAR: Hyperspherical Latents Improve Continuous-Token Autoregressive Generation

This repository contains the official PyTorch implementation of the paper [Hyperspherical Latents Improve Continuous-Token Autoregressive Generation](https://huggingface.co/papers/2509.24335).

<p align="center">
  <img src="https://github.com/guolinke/SphereAR/raw/main/figures/grid.jpg" width=780>
</p>

## Introduction

<p align="center"><img src="https://github.com/guolinke/SphereAR/raw/main/figures/overview.png" width=553><img src="https://github.com/guolinke/SphereAR/raw/main/figures/fid_vs_params.png" width=246></p>

SphereAR is a simple yet effective approach to continuous-token autoregressive (AR) image generation. It makes AR generation scale-invariant by constraining all AR inputs and outputs, **including after classifier-free guidance (CFG)**, to lie on a fixed-radius hypersphere (constant $\ell_2$ norm) via hyperspherical VAEs. A theoretical analysis shows that this hyperspherical constraint removes the scale component of the latents, which is the primary cause of variance collapse, thereby stabilizing AR decoding.

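The mechanism is easy to state in code: every latent token is rescaled to a constant $\ell_2$ norm, and the same projection is applied again after CFG mixing, so the AR model never has to predict scale. The sketch below is a minimal, illustrative PyTorch example, not the official implementation; the function names, the unit radius, the CFG mixing form, and the tensor shapes are assumptions.

```python
import torch
import torch.nn.functional as F

def project_to_sphere(z: torch.Tensor, radius: float = 1.0) -> torch.Tensor:
    # Rescale each latent token to a constant L2 norm (a fixed-radius hypersphere).
    return radius * F.normalize(z, dim=-1)

def cfg_on_sphere(z_cond: torch.Tensor, z_uncond: torch.Tensor,
                  cfg_scale: float, radius: float = 1.0) -> torch.Tensor:
    # Classifier-free guidance mixes conditional and unconditional predictions;
    # re-projecting afterwards keeps the AR input on the hypersphere.
    z = z_uncond + cfg_scale * (z_cond - z_uncond)
    return project_to_sphere(z, radius)

# Illustrative shapes: batch of 4 images, 256 latent tokens, 16-dim latents.
z_cond = project_to_sphere(torch.randn(4, 256, 16))
z_uncond = project_to_sphere(torch.randn(4, 256, 16))
z_next = cfg_on_sphere(z_cond, z_uncond, cfg_scale=4.5)
print(z_next.norm(dim=-1).mean())  # ~1.0: constant norm regardless of the CFG scale
```
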
The model is a **pure next-token** AR generator with **raster** order, matching standard language AR modeling. On ImageNet 256×256, SphereAR-H (943M) achieves a state-of-the-art FID of **1.34** among AR image generators. Even at smaller scales, SphereAR-L (479M) reaches FID 1.54 and SphereAR-B (208M) reaches 1.92, matching or surpassing much larger baselines.

For more details on the implementation, environment setup, and advanced usage, please refer to the [official GitHub repository](https://github.com/guolinke/SphereAR).

## Model Checkpoints

Pre-trained model checkpoints are available on Hugging Face:

| Name       | Params | FID (256×256) | Weights |
| :--------- | :----: | :-----------: | :------ |
| S-VAE      | 75M    | -             | [vae.pt](https://huggingface.co/guolinke/SphereAR/blob/main/vae.pt) |
| SphereAR-B | 208M   | 1.92          | [SphereAR_B.pt](https://huggingface.co/guolinke/SphereAR/blob/main/SphereAR_B.pt) |
| SphereAR-L | 479M   | 1.54          | [SphereAR_L.pt](https://huggingface.co/guolinke/SphereAR/blob/main/SphereAR_L.pt) |
| SphereAR-H | 943M   | 1.34          | [SphereAR_H.pt](https://huggingface.co/guolinke/SphereAR/blob/main/SphereAR_H.pt) |

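The weight files listed above can also be fetched programmatically with the `huggingface_hub` library. The snippet below is an optional convenience example (not from the official repository) and assumes the file names in the table stay unchanged.

```python
from huggingface_hub import hf_hub_download

# Download the S-VAE and SphereAR-H checkpoints from the Hugging Face Hub.
vae_path = hf_hub_download(repo_id="guolinke/SphereAR", filename="vae.pt")
ckpt_path = hf_hub_download(repo_id="guolinke/SphereAR", filename="SphereAR_H.pt")
print(vae_path, ckpt_path)
```
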
## Sample Usage: Class-conditional Image Generation

To sample 50,000 images using the `SphereAR-H` checkpoint for evaluation, you can use the following command adapted from the official repository. This requires a distributed setup (`torchrun`).

```shell
# First, download the SphereAR-H checkpoint (SphereAR_H.pt) and the S-VAE checkpoint (vae.pt)
# from the links in the "Model Checkpoints" table above.

ckpt=your_path_to/SphereAR_H.pt     # Path to your downloaded SphereAR_H.pt checkpoint
result_path=your_result_directory   # Directory to save generated images

torchrun --nnodes=1 --nproc_per_node=8 --node_rank=0 \
  sample_ddp.py --model SphereAR-H --ckpt $ckpt --cfg-scale 4.5 \
  --sample-dir $result_path --per-proc-batch-size 256 --to-npz
```

*Note: The `sample_ddp.py` script and its dependencies can be found in the [official GitHub repository](https://github.com/guolinke/SphereAR). Ensure your environment is set up according to its instructions, including PyTorch and FlashAttention.*

## Citation

If you find our work useful, please consider citing the paper:

```bibtex
@article{ke2025hyperspherical,
  title={Hyperspherical Latents Improve Continuous-Token Autoregressive Generation},
  author={Guolin Ke and Hui Xue},
  journal={arXiv preprint arXiv:2509.24335},
  year={2025}
}
```