from typing import List

import os
import random

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
import einops

import spaces
import gradio as gr

from flextok.flextok_wrapper import FlexTokFromHub
from flextok.utils.demo import imgs_from_urls, denormalize, batch_to_pil
from flextok.utils.misc import detect_bf16_support, get_bf16_context, get_generator

# We recommend running this demo on an A100 GPU
if torch.cuda.is_available():
    device = "cuda"
    gpu_type = torch.cuda.get_device_name(torch.cuda.current_device())
    power_device = f"{gpu_type}"
    torch.cuda.max_memory_allocated(device=device)
    # Detect whether bf16 is supported on this GPU
    enable_bf16 = detect_bf16_support()
    print(f'Device: {device}, GPU type: {gpu_type}')
    print('BF16 enabled:', enable_bf16)
else:
    # CPU is currently not supported. Please run on a GPU.
    device, power_device, enable_bf16 = "cpu", "CPU", False
    print('Running on CPU')

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Globally disable gradient computation (inference only)
torch.set_grad_enabled(False)

K_KEEP_LIST = [1, 2, 4, 8, 16, 32, 64, 128, 256]
MAX_SEED = np.iinfo(np.int32).max

MODEL_ID = 'EPFL-VILAB/flextok_d18_d28_dfn'
MODEL_NAME = 'FlexTok d18-d28 (DFN)'

# Load FlexTok model from HF Hub
flextok_model = FlexTokFromHub.from_pretrained(MODEL_ID).to(device).eval()
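
# API overview (inferred from how the model is used in infer() below; exact shapes are an
# assumption for this particular checkpoint): flextok_model.tokenize(imgs) returns one 1x256
# discrete token sequence per input image, ordered coarse-to-fine, and
# flextok_model.detokenize(list_of_token_subsequences, ...) decodes (possibly truncated)
# token prefixes back to images with the rectified-flow decoder.
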
# Disable flex_attention for HF Space
flextok_model.encoder.module_dict.enc_seq_packer.return_materialized_mask = True
flextok_model.decoder.module_dict.dec_seq_packer.return_materialized_mask = True
for block in flextok_model.encoder.module_dict.enc_transformer.blocks:
    block._checkpoint_wrapped_module.attn.use_flex_attention = False
for block in flextok_model.decoder.module_dict.dec_transformer.blocks:
    block._checkpoint_wrapped_module.attn.use_flex_attention = False
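# The flags above make the sequence packers return regular (materialized) attention masks and
# switch the encoder/decoder transformer blocks to their non-FlexAttention attention path.
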
# Load AuraSR model from HF Hub; if it cannot be loaded, the optional
# 4x super-resolution step is skipped.
try:
    from aura_sr import AuraSR
    aura_sr = AuraSR.from_pretrained("fal-ai/AuraSR")
except Exception:
    aura_sr = None

def img_from_path(
    path: str,
    img_size: int = 256,
    mean: List[float] = [0.5, 0.5, 0.5],
    std: List[float] = [0.5, 0.5, 0.5],
) -> torch.Tensor:
    # Image loading helper function
    img_pil = Image.open(path).convert("RGB")
    transform = transforms.Compose(
        [
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
    return transform(img_pil).unsqueeze(0)

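# Example usage (assuming a local file such as 'examples/0.png' exists):
#   imgs = img_from_path('examples/0.png')  # -> tensor of shape [1, 3, 256, 256], normalized to ~[-1, 1]
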
@spaces.GPU  # On ZeroGPU Spaces, request a GPU for the duration of each inference call
def infer(img_path, seed=1000, randomize_seed=False, timesteps=25, cfg_scale=7.5, perform_norm_guidance=True, super_res=False):
    if randomize_seed:
        seed = None

    imgs = img_from_path(img_path).to(device)

    # Tokenize the image once
    with get_bf16_context(enable_bf16):
        tokens = flextok_model.tokenize(imgs)[0]  # 1x256

    # Create all token subsequences
    subseq_list = [tokens[:, :k_keep].clone() for k_keep in K_KEEP_LIST]  # [1x1, 1x2, 1x4, ..., 1x256]

    # Detokenize the various subsequences in parallel. Batch size is 9.
    with get_bf16_context(enable_bf16):
        generator = get_generator(seed=seed, device=device)
        all_reconst = flextok_model.detokenize(
            subseq_list, timesteps=timesteps,
            guidance_scale=cfg_scale, perform_norm_guidance=perform_norm_guidance,
            generator=generator, verbose=False,
        )

    # Transform to PIL images with token/byte-count labels
    all_images = [
        (
            TF.to_pil_image(denormalize(reconst_k).clamp(0, 1)),
            '1 token (2 bytes)' if k_keep == 1 else f'{k_keep} tokens ({2*k_keep} bytes)'
        )
        for reconst_k, k_keep in zip(all_reconst, K_KEEP_LIST)
    ]

    # Optionally super-resolve to 1024x1024 (skipped if AuraSR failed to load)
    if super_res and aura_sr is not None:
        all_images = [(aura_sr.upscale_4x(img), label) for img, label in all_images]

    return all_images

examples = [
    'examples/0.png', 'examples/1.png', 'examples/2.png',
    'examples/3.png', 'examples/4.png', 'examples/5.png',
]
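
# Optional smoke test (a sketch, not part of the demo UI): set FLEXTOK_SMOKE_TEST=1 to run a single
# reconstruction outside of Gradio before launching. Assumes the example images above exist locally
# and that a GPU is available.
if os.environ.get('FLEXTOK_SMOKE_TEST') == '1':
    for img, label in infer(examples[0], seed=0, timesteps=25, cfg_scale=7.5):
        print(label, img.size)
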
css = """
#col-container {
    margin: 0 auto;
    max-width: 1500px;
}
#col-input-container {
    margin: 0 auto;
    max-width: 400px;
}
#run-button {
    margin: 0 auto;
}
#gallery {
    aspect-ratio: 1/1 !important;
    height: auto !important;
}
"""

with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # FlexTok: Resampling Images into 1D Token Sequences of Flexible Length
        """)

        with gr.Row():
            with gr.Column(elem_id="col-input-container"):
                gr.Markdown(f"""
                [`Website`](https://flextok.epfl.ch) | [`arXiv`](https://arxiv.org/abs/2502.13967) | [`GitHub`](https://github.com/apple/ml-flextok)

                Research demo for: <br>
                [**FlexTok: Resampling Images into 1D Token Sequences of Flexible Length**](https://arxiv.org/abs/2502.13967), arXiv 2025 <br>

                This demo uses the FlexTok tokenizer to autoencode the given RGB input, using [{MODEL_ID}](https://huggingface.co/{MODEL_ID}), running on *{power_device}*.
                The FlexTok encoder produces a 1D sequence of discrete tokens that are ordered in a coarse-to-fine manner.
                We show reconstructions from truncated subsequences, using the first 1, 2, 4, 8, ..., 256 tokens.
                As you will see, the first tokens capture the high-level semantic content, while subsequent ones add fine-grained detail.
                """)

                img_path = gr.Image(label='RGB input image', type='filepath')
                run_button = gr.Button(f"Autoencode with {MODEL_NAME}", scale=0, elem_id="run-button")

                with gr.Accordion("Advanced Settings", open=False):
                    gr.Markdown("""
                    The FlexTok decoder is a rectified flow model. The following settings control the seed of the initial noise, the number of denoising timesteps,
                    the guidance scale, and whether to perform [Adaptive Projected Guidance](https://arxiv.org/abs/2410.02416) (we recommend enabling it).

                    This FlexTok model operates at 256x256 resolution. You can optionally super-resolve the reconstructions to 1024x1024 using
                    [AuraSR](https://huggingface.co/fal/AuraSR) for sharper details, without changing the underlying reconstructed image too much.
                    """)
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=1000)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
                    timesteps = gr.Slider(label="Denoising timesteps", minimum=1, maximum=1000, step=1, value=25)
                    cfg_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=15.0, step=0.1, value=7.5)
                    perform_norm_guidance = gr.Checkbox(label="Perform Adaptive Projected Guidance", value=True)
                    super_res = gr.Checkbox(label="Super-resolve reconstructions from 256x256 to 1024x1024 with AuraSR", value=False)

            result = gr.Gallery(
                label="Reconstructions", show_label=True, elem_id="gallery", type='pil',
                columns=[3], rows=None, object_fit="contain", height=800
            )

        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[img_path],
            outputs=[result],
            cache_examples='lazy',
        )
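        # Note: clicked/cached examples only pass the image path to infer(), so the default seed,
        # timesteps, guidance, and super-resolution settings are used for them.
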
        run_button.click(
            fn=infer,
            inputs=[img_path, seed, randomize_seed, timesteps, cfg_scale, perform_norm_guidance, super_res],
            outputs=[result]
        )

demo.queue(max_size=10).launch(share=True)