File size: 8,052 Bytes
a4f1fc6
 
 
6862991
 
a4f1fc6
 
6862991
a4f1fc6
 
6862991
a4f1fc6
 
6862991
a4f1fc6
 
 
6862991
a4f1fc6
6862991
a4f1fc6
 
 
 
 
 
d4afa02
 
6862991
7d2453d
a4f1fc6
d4afa02
a4f1fc6
 
 
 
 
 
 
 
 
6862991
5647d41
6862991
a4f1fc6
 
 
 
 
 
9199279
 
 
 
 
 
 
 
a4f1fc6
0c0b8b9
 
 
 
6b03d29
0c0b8b9
 
a4f1fc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a53bced
a17a70f
6862991
a4f1fc6
 
 
 
 
 
5647d41
6862991
a4f1fc6
5647d41
6862991
a4f1fc6
 
 
 
 
 
 
 
 
 
 
c8ecc15
 
1085def
c8ecc15
5647d41
a4f1fc6
6862991
0c0b8b9
 
 
a4f1fc6
6862991
 
 
a4f1fc6
 
6862991
 
a4f1fc6
6862991
 
a4f1fc6
 
 
 
 
 
 
 
 
 
 
 
6862991
 
 
a4f1fc6
 
6862991
a4f1fc6
 
 
 
6862991
a4f1fc6
 
 
 
26aa516
a4f1fc6
 
0c0b8b9
 
 
 
a4f1fc6
 
 
 
 
 
 
0c0b8b9
 
 
e6a3880
 
a4f1fc6
b2e1f99
a4f1fc6
b2e1f99
a4f1fc6
 
a17a70f
a4f1fc6
 
 
 
6862991
 
a4f1fc6
 
 
 
 
 
 
 
 
 
0c0b8b9
a4f1fc6
6862991
 
f8de4d8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
from typing import List, Sequence
import os
import spaces
import gradio as gr
import random
from PIL import Image
import matplotlib.pyplot as plt

import einops
import numpy as np
import torch
from torchvision import transforms
import torchvision.transforms.functional as TF

from flextok.flextok_wrapper import FlexTokFromHub
from flextok.utils.demo import imgs_from_urls, denormalize, batch_to_pil
from flextok.utils.misc import detect_bf16_support, get_bf16_context, get_generator

# We recommend running this demo on an A100 GPU
if torch.cuda.is_available():
    device = "cuda"
    gpu_type = torch.cuda.get_device_name(torch.cuda.current_device())
    # Human-readable device label, shown in the UI description.
    power_device = gpu_type
    # Detect if bf16 is enabled or not
    enable_bf16 = detect_bf16_support()
    print(f'Device: {device}, GPU type: {gpu_type}')
    print('BF16 enabled:', enable_bf16)
else:
    # Currently not supported. Please run on GPUs.
    device, power_device, enable_bf16 = "cpu", "CPU", False
    print('Running on CPU')


# Enable TF32 for both matmul (PyTorch default: False since 1.12) and
# cuDNN (PyTorch default: True).
for _tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
    _tf32_backend.allow_tf32 = True

# The demo only runs inference, so autograd is switched off for the whole process.
torch.set_grad_enabled(False)

# Prefix lengths to reconstruct from: the powers of two from 1 up to 256 tokens.
K_KEEP_LIST = [2 ** exponent for exponent in range(9)]
# Upper bound of the seed slider (largest int32).
MAX_SEED = np.iinfo(np.int32).max

# Hugging Face Hub checkpoint and its display name in the UI.
MODEL_ID, MODEL_NAME = 'EPFL-VILAB/flextok_d18_d28_dfn', 'FlexTok d18-d28 (DFN)'

# Load FlexTok model from HF Hub, move it to the selected device and put it in eval mode.
flextok_model = FlexTokFromHub.from_pretrained(MODEL_ID)
flextok_model = flextok_model.to(device).eval()

# Disable flex_attention for HF Space: for both the encoder and the decoder,
# ask the sequence packer for a materialized mask and switch every
# transformer block's attention off the flex_attention path.
for _stage, _packer_name, _transformer_name in (
    (flextok_model.encoder, 'enc_seq_packer', 'enc_transformer'),
    (flextok_model.decoder, 'dec_seq_packer', 'dec_transformer'),
):
    getattr(_stage.module_dict, _packer_name).return_materialized_mask = True
    for block in getattr(_stage.module_dict, _transformer_name).blocks:
        block._checkpoint_wrapped_module.attn.use_flex_attention = False

# Load the optional AuraSR upscaler from HF Hub. The demo still works without
# it; `aura_sr` is then None and super-resolution is unavailable.
try:
    from aura_sr import AuraSR
    aura_sr = AuraSR.from_pretrained("fal-ai/AuraSR")
except Exception as err:  # e.g. package not installed or download/loading failure
    # Catch Exception rather than using a bare `except:` so that
    # KeyboardInterrupt/SystemExit still propagate, and report why it failed.
    print(f'AuraSR unavailable, super-resolution disabled: {err!r}')
    aura_sr = None


def img_from_path(
    path: str,
    img_size: int = 256,
    mean: Sequence[float] = (0.5, 0.5, 0.5),
    std: Sequence[float] = (0.5, 0.5, 0.5),
) -> torch.Tensor:
    """Load an image file as a normalized, batched tensor.

    The image is converted to RGB and its shorter side is resized to
    ``img_size``. It is then center-cropped to ``img_size`` x ``img_size``,
    converted to a tensor and normalized with ``mean``/``std``.

    Args:
        path: Path to the image file.
        img_size: Side length of the square output.
        mean: Per-channel mean passed to ``transforms.Normalize``.
        std: Per-channel std passed to ``transforms.Normalize``.

    Returns:
        Tensor of shape ``(1, 3, img_size, img_size)``.
    """
    # Open inside a context manager so the file handle is released;
    # convert() decodes the pixels and returns an independent image.
    with Image.open(path) as img:
        img_pil = img.convert("RGB")

    transform = transforms.Compose(
        [
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )

    # Add the leading batch dimension expected by the model.
    return transform(img_pil).unsqueeze(0)


@spaces.GPU(duration=30)
def infer(img_path, seed=1000, randomize_seed=False, timesteps=25, cfg_scale=7.5, perform_norm_guidance=True, super_res=False):
    """Autoencode one image with FlexTok and reconstruct it from token prefixes.

    The image is tokenized once. It is then decoded from the first ``k`` tokens
    for every ``k`` in ``K_KEEP_LIST``, using a single ``detokenize`` call.

    Args:
        img_path: Filepath of the input image (``gr.Image(type='filepath')``);
            ``None`` when nothing has been uploaded.
        seed: Seed for the decoder's initial noise.
        randomize_seed: If True, ``seed`` is ignored and ``None`` is passed to
            ``get_generator``.
        timesteps: Number of denoising timesteps of the decoder.
        cfg_scale: Guidance scale.
        perform_norm_guidance: Whether to perform Adaptive Projected Guidance.
        super_res: If True and AuraSR is loaded, upscale every reconstruction
            with ``aura_sr.upscale_4x``.

    Returns:
        List of ``(PIL image, caption)`` pairs, one per prefix length.

    Raises:
        gr.Error: If no input image was provided.
    """
    if img_path is None:
        # Show a readable message in the UI instead of a PIL traceback.
        raise gr.Error('Please upload an input image first.')

    if randomize_seed:
        seed = None

    imgs = img_from_path(img_path).to(device)

    # Tokenize images once
    with get_bf16_context(enable_bf16):
        tokens = flextok_model.tokenize(imgs)[0] # 1x256 (per original author's note)

    # Create all token subsequences
    subseq_list = [tokens[:, :k_keep].clone() for k_keep in K_KEEP_LIST] # [1x1, 1x2, 1x4, ..., 1x256]

    # Detokenize the various subsequences in one call (one per K_KEEP_LIST entry).
    with get_bf16_context(enable_bf16):
        generator = get_generator(seed=seed, device=device)
        all_reconst = flextok_model.detokenize(
            subseq_list, timesteps=timesteps,
            guidance_scale=cfg_scale, perform_norm_guidance=perform_norm_guidance,
            generator=generator, verbose=False,
        )

    # Transform to PIL images
    all_images = [
        (
            TF.to_pil_image(denormalize(reconst_k).clamp(0, 1)),
            '1 token (2 bytes)' if k_keep == 1 else f'{k_keep} tokens ({2*k_keep} bytes)'
        )
        for reconst_k, k_keep in zip(all_reconst, K_KEEP_LIST)
    ]

    if super_res:
        if aura_sr is None:
            # AuraSR failed to load at startup; return the 256x256 results
            # instead of crashing on `None.upscale_4x`.
            gr.Warning('Aura-SR is not available; showing reconstructions without super-resolution.')
        else:
            all_images = [(aura_sr.upscale_4x(img), label) for img, label in all_images]

    return all_images


# Example input images (relative paths) offered below the demo.
examples = [f'examples/{idx}.png' for idx in range(6)]

# Custom CSS for the Blocks layout below, targeting components by elem_id:
# centers the outer column (max 1500px) and the input column (max 400px),
# centers the run button, and keeps the gallery at a 1:1 aspect ratio.
css="""
#col-container {
    margin: 0 auto;
    max-width: 1500px;
}
#col-input-container {
    margin: 0 auto;
    max-width: 400px;
}
#run-button {
    margin: 0 auto;
}
#gallery {
    aspect-ratio: 1/1 !important;
    height: auto !important;
}
"""

with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:

    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # FlexTok: Resampling Images into 1D Token Sequences of Flexible Length
        """)

        with gr.Row():
            # Input side: description, image upload, run button, advanced settings.
            with gr.Column(elem_id="col-input-container"):
                gr.Markdown(f"""
                [`Website`](https://flextok.epfl.ch) | [`arXiv`](https://arxiv.org/abs/2502.13967) | [`GitHub`](https://github.com/apple/ml-flextok)

                Research demo for: <br>
                [**FlexTok: Resampling Images into 1D Token Sequences of Flexible Length**](https://arxiv.org/abs/2502.13967), arXiv 2025 <br>

                This demo uses the FlexTok tokenizer to autoencode the given RGB input, using [{MODEL_ID}](https://huggingface.co/{MODEL_ID}), running on *{power_device}*. 
                The FlexTok encoder produces a 1D sequence of discrete tokens that are ordered in a coarse-to-fine manner. 
                We show reconstructions from truncated subsequences, using the first 1, 2, 4, 8, ..., 256 tokens. 
                As you will see, the first tokens capture more high-level semantic content, while subsequent ones add fine-grained detail.
                """)

                img_path = gr.Image(label='RGB input image', type='filepath')
                run_button = gr.Button(f"Autoencode with {MODEL_NAME}", scale=0, elem_id="run-button")

                with gr.Accordion("Advanced Settings", open=False):
                    gr.Markdown("""
                    The FlexTok decoder is a rectified flow model. The following settings control the seed of the initial noise, the number of denoising timesteps, 
                    the guidance scale, and whether to perform [Adaptive Projected Guidance](https://arxiv.org/abs/2410.02416) (we recommend enabling it).
                    
                    This FlexTok model operates at 256x256 resolution. You can optionally super-resolve the reconstructions to 1024x1024 using 
                    [Aura-SR](https://huggingface.co/fal/AuraSR) for sharper details, without changing the underlying reconstructed image too much.
                    """)
                    # Initial widget values match the keyword defaults of infer().
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=1000)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
                    timesteps = gr.Slider(label="Denoising timesteps", minimum=1, maximum=1000, step=1, value=25)
                    cfg_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=15.0, step=0.1, value=7.5)
                    perform_norm_guidance = gr.Checkbox(label="Perform Adaptive Projected Guidance", value=True)
                    super_res = gr.Checkbox(label="Super-resolve reconstructions from 256x256 to 1024x1024 with Aura-SR", value=False)

            # Output side: one reconstruction per token-prefix length.
            result = gr.Gallery(
                label="Reconstructions", show_label=True, elem_id="gallery", type='pil',
                columns=[3], rows=None, object_fit="contain", height=800
            )

        # Examples pass only the image, so infer() runs with its default settings.
        gr.Examples(
            examples = examples,
            fn = infer,
            inputs = [img_path],
            outputs = [result],
            cache_examples='lazy',
        )

    # The button forwards every user-controllable setting to infer().
    run_button.click(
        fn = infer,
        inputs = [img_path, seed, randomize_seed, timesteps, cfg_scale, perform_norm_guidance, super_res],
        outputs = [result]
    )

# Limit the request queue to 10 pending jobs, then start the app.
demo.queue(max_size=10)
demo.launch(share=True)