from typing import List

import os
import spaces
import gradio as gr

import random
from PIL import Image
import matplotlib.pyplot as plt
import einops
import numpy as np

import torch
from torchvision import transforms
import torchvision.transforms.functional as TF

from flextok.flextok_wrapper import FlexTokFromHub
from flextok.utils.demo import imgs_from_urls, denormalize, batch_to_pil
from flextok.utils.misc import detect_bf16_support, get_bf16_context, get_generator
# We recommend running this demo on an A100 GPU
if torch.cuda.is_available():
    device = "cuda"
    gpu_type = torch.cuda.get_device_name(torch.cuda.current_device())
    power_device = f"{gpu_type}"
    torch.cuda.max_memory_allocated(device=device)

    # Detect if bf16 is enabled or not
    enable_bf16 = detect_bf16_support()
    print(f'Device: {device}, GPU type: {gpu_type}')
    print('BF16 enabled:', enable_bf16)
else:
    # Currently not supported. Please run on GPUs.
    device, power_device, enable_bf16 = "cpu", "CPU", False
    print('Running on CPU')
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Global no_grad
torch.set_grad_enabled(False)
# Token prefix lengths to reconstruct from (coarse-to-fine: 1, 2, 4, ..., 256 tokens)
K_KEEP_LIST = [1, 2, 4, 8, 16, 32, 64, 128, 256]
MAX_SEED = np.iinfo(np.int32).max

MODEL_ID = 'EPFL-VILAB/flextok_d18_d28_dfn'
MODEL_NAME = 'FlexTok d18-d28 (DFN)'
# Load FlexTok model from HF Hub
flextok_model = FlexTokFromHub.from_pretrained(MODEL_ID).to(device).eval()

# Disable flex_attention for HF Space
flextok_model.encoder.module_dict.enc_seq_packer.return_materialized_mask = True
flextok_model.decoder.module_dict.dec_seq_packer.return_materialized_mask = True
for block in flextok_model.encoder.module_dict.enc_transformer.blocks:
    block._checkpoint_wrapped_module.attn.use_flex_attention = False
for block in flextok_model.decoder.module_dict.dec_transformer.blocks:
    block._checkpoint_wrapped_module.attn.use_flex_attention = False
# Load AuraSR model from HF Hub (optional, only needed for 4x super-resolution)
try:
    from aura_sr import AuraSR
    aura_sr = AuraSR.from_pretrained("fal-ai/AuraSR")
except Exception:
    aura_sr = None
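# Note: if AuraSR could not be loaded, `aura_sr` stays None and the super-resolution
# option in `infer` below is skipped instead of raising an error.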
def img_from_path(
    path: str,
    img_size: int = 256,
    mean: List[float] = [0.5, 0.5, 0.5],
    std: List[float] = [0.5, 0.5, 0.5],
) -> torch.Tensor:
    # Image loading helper function
    img_pil = Image.open(path).convert("RGB")

    transform = transforms.Compose(
        [
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )

    return transform(img_pil).unsqueeze(0)
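# Example usage (illustrative; assumes an image exists at 'examples/0.png'):
#   x = img_from_path('examples/0.png')  # -> torch.Size([1, 3, 256, 256]), values normalized to [-1, 1]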
# Request a ZeroGPU allocation for each call (this Space runs on Zero).
@spaces.GPU
def infer(img_path, seed=1000, randomize_seed=False, timesteps=25, cfg_scale=7.5, perform_norm_guidance=True, super_res=False):
    if randomize_seed:
        seed = None

    imgs = img_from_path(img_path).to(device)

    # Tokenize images once
    with get_bf16_context(enable_bf16):
        tokens = flextok_model.tokenize(imgs)[0]  # 1x256

    # Create all token subsequences
    subseq_list = [tokens[:, :k_keep].clone() for k_keep in K_KEEP_LIST]  # [1x1, 1x2, 1x4, ..., 1x256]

    # Detokenize various subsequences in parallel. Batch size is 9.
    with get_bf16_context(enable_bf16):
        generator = get_generator(seed=seed, device=device)
        all_reconst = flextok_model.detokenize(
            subseq_list, timesteps=timesteps,
            guidance_scale=cfg_scale, perform_norm_guidance=perform_norm_guidance,
            generator=generator, verbose=False,
        )

    # Transform to PIL images
    all_images = [
        (
            TF.to_pil_image(denormalize(reconst_k).clamp(0, 1)),
            '1 token (2 bytes)' if k_keep == 1 else f'{k_keep} tokens ({2*k_keep} bytes)'
        )
        for reconst_k, k_keep in zip(all_reconst, K_KEEP_LIST)
    ]

    # Optionally super-resolve the reconstructions (skipped if AuraSR failed to load)
    if super_res and aura_sr is not None:
        all_images = [(aura_sr.upscale_4x(img), label) for img, label in all_images]

    return all_images
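# Quick local sanity check (illustrative sketch, not executed by the Space;
# assumes an example image exists at 'examples/0.png'):
#
#   gallery = infer('examples/0.png', seed=0, randomize_seed=False, timesteps=25,
#                   cfg_scale=7.5, perform_norm_guidance=True, super_res=False)
#   for img, label in gallery:
#       print(label, img.size)  # e.g. "4 tokens (8 bytes)" (256, 256)
#   gallery[-1][0].save('reconstruction_256_tokens.png')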
examples = [
    'examples/0.png', 'examples/1.png', 'examples/2.png',
    'examples/3.png', 'examples/4.png', 'examples/5.png',
]
| css=""" | |
| #col-container { | |
| margin: 0 auto; | |
| max-width: 1500px; | |
| } | |
| #col-input-container { | |
| margin: 0 auto; | |
| max-width: 400px; | |
| } | |
| #run-button { | |
| margin: 0 auto; | |
| } | |
| #gallery { | |
| aspect-ratio: 1/1 !important; | |
| height: auto !important; | |
| } | |
| """ | |
with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
        # FlexTok: Resampling Images into 1D Token Sequences of Flexible Length
        """)

        with gr.Row():
            with gr.Column(elem_id="col-input-container"):
                gr.Markdown(f"""
                [`Website`](https://flextok.epfl.ch) | [`arXiv`](https://arxiv.org/abs/2502.13967) | [`GitHub`](https://github.com/apple/ml-flextok)

                Research demo for: <br>
                [**FlexTok: Resampling Images into 1D Token Sequences of Flexible Length**](https://arxiv.org/abs/2502.13967), arXiv 2025 <br>

                This demo uses the FlexTok tokenizer to autoencode the given RGB input, using [{MODEL_ID}](https://huggingface.co/{MODEL_ID}), running on *{power_device}*.
                The FlexTok encoder produces a 1D sequence of discrete tokens that are ordered in a coarse-to-fine manner.
                We show reconstructions from truncated subsequences, using the first 1, 2, 4, 8, ..., 256 tokens.
                As you will see, the first tokens capture more high-level semantic content, while subsequent ones add fine-grained detail.
                """)
                img_path = gr.Image(label='RGB input image', type='filepath')
                run_button = gr.Button(f"Autoencode with {MODEL_NAME}", scale=0, elem_id="run-button")
                with gr.Accordion("Advanced Settings", open=False):
                    gr.Markdown(f"""
                    The FlexTok decoder is a rectified flow model. The following settings control the seed of the initial noise, the number of denoising timesteps,
                    the guidance scale, and whether to perform [Adaptive Projected Guidance](https://arxiv.org/abs/2410.02416) (we recommend enabling it).

                    This FlexTok model operates at 256x256 resolution. You can optionally super-resolve the reconstructions to 1024x1024 using
                    [Aura-SR](https://huggingface.co/fal/AuraSR) for sharper details, without changing the underlying reconstructed image too much.
                    """)
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=1000)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
                    timesteps = gr.Slider(label="Denoising timesteps", minimum=1, maximum=1000, step=1, value=25)
                    cfg_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=15.0, step=0.1, value=7.5)
                    perform_norm_guidance = gr.Checkbox(label="Perform Adaptive Projected Guidance", value=True)
                    super_res = gr.Checkbox(label="Super-resolve reconstructions from 256x256 to 1024x1024 with Aura-SR", value=False)
            result = gr.Gallery(
                label="Reconstructions", show_label=True, elem_id="gallery", type='pil',
                columns=[3], rows=None, object_fit="contain", height=800
            )

        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[img_path],
            outputs=[result],
            cache_examples='lazy',
        )

    run_button.click(
        fn=infer,
        inputs=[img_path, seed, randomize_seed, timesteps, cfg_scale, perform_norm_guidance, super_res],
        outputs=[result]
    )

demo.queue(max_size=10).launch(share=True)