from typing import List

import os
import spaces

import gradio as gr
import random
from PIL import Image
import matplotlib.pyplot as plt

import einops
import numpy as np
import torch
from torchvision import transforms
import torchvision.transforms.functional as TF

from flextok.flextok_wrapper import FlexTokFromHub
from flextok.utils.demo import imgs_from_urls, denormalize, batch_to_pil
from flextok.utils.misc import detect_bf16_support, get_bf16_context, get_generator

# We recommend running this demo on an A100 GPU
if torch.cuda.is_available():
    device = "cuda"
    gpu_type = torch.cuda.get_device_name(torch.cuda.current_device())
    power_device = f"{gpu_type}"
    torch.cuda.max_memory_allocated(device=device)

    # Detect if bf16 is enabled or not
    enable_bf16 = detect_bf16_support()
    print(f'Device: {device}, GPU type: {gpu_type}')
    print('BF16 enabled:', enable_bf16)
else:
    # Currently not supported. Please run on GPUs.
    device, power_device, enable_bf16 = "cpu", "CPU", False
    print('Running on CPU')

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Global no_grad
torch.set_grad_enabled(False)

K_KEEP_LIST = [1, 2, 4, 8, 16, 32, 64, 128, 256]
MAX_SEED = np.iinfo(np.int32).max

MODEL_ID = 'EPFL-VILAB/flextok_d18_d28_dfn'
MODEL_NAME = 'FlexTok d18-d28 (DFN)'

# Load FlexTok model from HF Hub
flextok_model = FlexTokFromHub.from_pretrained(MODEL_ID).to(device).eval()

# Disable flex_attention for HF Space
flextok_model.encoder.module_dict.enc_seq_packer.return_materialized_mask = True
flextok_model.decoder.module_dict.dec_seq_packer.return_materialized_mask = True
for block in flextok_model.encoder.module_dict.enc_transformer.blocks:
    block._checkpoint_wrapped_module.attn.use_flex_attention = False
for block in flextok_model.decoder.module_dict.dec_transformer.blocks:
    block._checkpoint_wrapped_module.attn.use_flex_attention = False

# Load AuraSR model from HF Hub
try:
    from aura_sr import AuraSR
    aura_sr = AuraSR.from_pretrained("fal-ai/AuraSR")
except Exception:
    aura_sr = None


def img_from_path(
    path: str,
    img_size: int = 256,
    mean: List[float] = [0.5, 0.5, 0.5],
    std: List[float] = [0.5, 0.5, 0.5],
) -> torch.Tensor:
    # Image loading helper function
    img_pil = Image.open(path).convert("RGB")
    transform = transforms.Compose(
        [
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
    return transform(img_pil).unsqueeze(0)


@spaces.GPU(duration=30)
def infer(img_path, seed=1000, randomize_seed=False, timesteps=25, cfg_scale=7.5, perform_norm_guidance=True, super_res=False):
    if randomize_seed:
        seed = None

    imgs = img_from_path(img_path).to(device)

    # Tokenize images once
    with get_bf16_context(enable_bf16):
        tokens = flextok_model.tokenize(imgs)[0]  # 1x256

    # Create all token subsequences
    subseq_list = [tokens[:, :k_keep].clone() for k_keep in K_KEEP_LIST]  # [1x1, 1x2, 1x4, ..., 1x256]

    # Detokenize various subsequences in parallel. Batch size is 9.
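    # Note: `detokenize` runs the rectified-flow decoder over all 9 token prefixes jointly;
    # shorter prefixes yield coarser, more semantic reconstructions (see the demo description below).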
    with get_bf16_context(enable_bf16):
        generator = get_generator(seed=seed, device=device)
        all_reconst = flextok_model.detokenize(
            subseq_list,
            timesteps=timesteps,
            guidance_scale=cfg_scale,
            perform_norm_guidance=perform_norm_guidance,
            generator=generator,
            verbose=False,
        )

    # Transform to PIL images
    all_images = [
        (
            TF.to_pil_image(denormalize(reconst_k).clamp(0, 1)),
            '1 token (2 bytes)' if k_keep == 1 else f'{k_keep} tokens ({2*k_keep} bytes)'
        )
        for reconst_k, k_keep in zip(all_reconst, K_KEEP_LIST)
    ]

    # Optionally super-resolve with AuraSR (skipped if the model failed to load)
    if super_res and aura_sr is not None:
        all_images = [(aura_sr.upscale_4x(img), label) for img, label in all_images]

    return all_images


examples = [
    'examples/0.png',
    'examples/1.png',
    'examples/2.png',
    'examples/3.png',
    'examples/4.png',
    'examples/5.png',
]

css="""
#col-container {
    margin: 0 auto;
    max-width: 1500px;
}
#col-input-container {
    margin: 0 auto;
    max-width: 400px;
}
#run-button {
    margin: 0 auto;
}
#gallery {
    aspect-ratio: 1/1 !important;
    height: auto !important;
}
"""

with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
        # FlexTok: Resampling Images into 1D Token Sequences of Flexible Length
        """)

        with gr.Row():
            with gr.Column(elem_id="col-input-container"):
                gr.Markdown(f"""
                [`Website`](https://flextok.epfl.ch) | [`arXiv`](https://arxiv.org/abs/2502.13967) | [`GitHub`](https://github.com/apple/ml-flextok)

                Research demo for:
                [**FlexTok: Resampling Images into 1D Token Sequences of Flexible Length**](https://arxiv.org/abs/2502.13967), arXiv 2025

                This demo uses the FlexTok tokenizer to autoencode the given RGB input, using [{MODEL_ID}](https://huggingface.co/{MODEL_ID}), running on *{power_device}*.

                The FlexTok encoder produces a 1D sequence of discrete tokens that are ordered in a coarse-to-fine manner. We show reconstructions from truncated subsequences, using the first 1, 2, 4, 8, ..., 256 tokens. As you will see, the first tokens capture high-level semantic content, while subsequent ones add increasingly fine-grained detail.
                """)

                img_path = gr.Image(label='RGB input image', type='filepath')

                run_button = gr.Button(f"Autoencode with {MODEL_NAME}", scale=0, elem_id="run-button")

                with gr.Accordion("Advanced Settings", open=False):
                    gr.Markdown(f"""
                    The FlexTok decoder is a rectified flow model. The following settings control the seed of the initial noise, the number of denoising timesteps, the guidance scale, and whether to perform [Adaptive Projected Guidance](https://arxiv.org/abs/2410.02416) (we recommend enabling it).

                    This FlexTok model operates at 256x256 resolution. You can optionally super-resolve the reconstructions to 1024x1024 using [Aura-SR](https://huggingface.co/fal/AuraSR) for sharper details, without changing the underlying reconstructed image too much.
                    """)

                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=1000)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
                    timesteps = gr.Slider(label="Denoising timesteps", minimum=1, maximum=1000, step=1, value=25)
                    cfg_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=15.0, step=0.1, value=7.5)
                    perform_norm_guidance = gr.Checkbox(label="Perform Adaptive Projected Guidance", value=True)
                    super_res = gr.Checkbox(label="Super-resolve reconstructions from 256x256 to 1024x1024 with Aura-SR", value=False)

            result = gr.Gallery(
                label="Reconstructions", show_label=True, elem_id="gallery",
                type='pil', columns=[3], rows=None, object_fit="contain", height=800
            )

        gr.Examples(
            examples = examples,
            fn = infer,
            inputs = [img_path],
            outputs = [result],
            cache_examples='lazy',
        )

    run_button.click(
        fn = infer,
        inputs = [img_path, seed, randomize_seed, timesteps, cfg_scale, perform_norm_guidance, super_res],
        outputs = [result]
    )

demo.queue(max_size=10).launch(share=True)
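# ----------------------------------------------------------------------------------------------
# Programmatic usage sketch (comments only, since `demo.launch()` above blocks): assuming the
# bundled example image 'examples/0.png' is present, `infer` can also be called directly and
# returns (PIL image, label) pairs for the 1, 2, 4, ..., 256-token reconstructions:
#
#     gallery = infer('examples/0.png', seed=0, randomize_seed=False, timesteps=25,
#                     cfg_scale=7.5, perform_norm_guidance=True, super_res=False)
#     gallery[-1][0].save('reconstruction_256_tokens.png')
# ----------------------------------------------------------------------------------------------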