from typing import List
import os
import spaces
import gradio as gr
import random
from PIL import Image
import matplotlib.pyplot as plt
import einops
import numpy as np
import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
from flextok.flextok_wrapper import FlexTokFromHub
from flextok.utils.demo import imgs_from_urls, denormalize, batch_to_pil
from flextok.utils.misc import detect_bf16_support, get_bf16_context, get_generator
# We recommend running this demo on an A100 GPU
if torch.cuda.is_available():
    device = "cuda"
    gpu_type = torch.cuda.get_device_name(torch.cuda.current_device())
    power_device = f"{gpu_type}"
    torch.cuda.max_memory_allocated(device=device)
    # Detect whether bf16 is supported on this GPU
    enable_bf16 = detect_bf16_support()
    print(f'Device: {device}, GPU type: {gpu_type}')
    print('BF16 enabled:', enable_bf16)
else:
    # CPU execution is currently not supported. Please run on a GPU.
    device, power_device, enable_bf16 = "cpu", "CPU", False
    print('Running on CPU')
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True
# Global no_grad
torch.set_grad_enabled(False)
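# Nested token budgets to visualize: the demo reconstructs from the first 1, 2, 4, ..., 256 tokens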
K_KEEP_LIST = [1, 2, 4, 8, 16, 32, 64, 128, 256]
MAX_SEED = np.iinfo(np.int32).max
MODEL_ID = 'EPFL-VILAB/flextok_d18_d28_dfn'
MODEL_NAME = 'FlexTok d18-d28 (DFN)'
# Load FlexTok model from HF Hub
flextok_model = FlexTokFromHub.from_pretrained(MODEL_ID).to(device).eval()
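# `flextok_model` exposes the two entry points used below: `tokenize(imgs)`, which maps a batch of
# images to 1D token sequences, and `detokenize(token_list, ...)`, which decodes (sub)sequences back
# into images with the rectified-flow decoder.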
# Disable flex_attention for the HF Space
flextok_model.encoder.module_dict.enc_seq_packer.return_materialized_mask = True
flextok_model.decoder.module_dict.dec_seq_packer.return_materialized_mask = True
for block in flextok_model.encoder.module_dict.enc_transformer.blocks:
    block._checkpoint_wrapped_module.attn.use_flex_attention = False
for block in flextok_model.decoder.module_dict.dec_transformer.blocks:
    block._checkpoint_wrapped_module.attn.use_flex_attention = False
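# Note: this falls back to standard (non-flex) attention; flex_attention is assumed to be
# unavailable or unsupported in this Space's environment.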
# Load the AuraSR super-resolution model from HF Hub (optional; only needed for the super-resolution option)
try:
    from aura_sr import AuraSR
    aura_sr = AuraSR.from_pretrained("fal-ai/AuraSR")
except Exception:
    # If aura_sr is not installed or the weights cannot be fetched, disable super-resolution
    aura_sr = None
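# Hypothetical standalone usage of the upscaler (only valid if `aura_sr` loaded successfully):
#   img_1024 = aura_sr.upscale_4x(img_256)  # PIL 256x256 -> PIL 1024x1024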
def img_from_path(
    path: str,
    img_size: int = 256,
    mean: List[float] = [0.5, 0.5, 0.5],
    std: List[float] = [0.5, 0.5, 0.5],
) -> torch.Tensor:
    # Image loading helper: resize, center-crop, and normalize to a 1x3xHxW tensor
    img_pil = Image.open(path).convert("RGB")
    transform = transforms.Compose(
        [
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
    return transform(img_pil).unsqueeze(0)
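# Minimal usage sketch (hypothetical local call, assuming one of the bundled example images exists):
#   imgs = img_from_path('examples/0.png')  # -> tensor of shape [1, 3, 256, 256], values roughly in [-1, 1]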
@spaces.GPU(duration=30)
def infer(img_path, seed=1000, randomize_seed=False, timesteps=25, cfg_scale=7.5, perform_norm_guidance=True, super_res=False):
    if randomize_seed:
        seed = None
    imgs = img_from_path(img_path).to(device)

    # Tokenize the image once
    with get_bf16_context(enable_bf16):
        tokens = flextok_model.tokenize(imgs)[0]  # 1x256

    # Create all nested token subsequences
    subseq_list = [tokens[:, :k_keep].clone() for k_keep in K_KEEP_LIST]  # [1x1, 1x2, 1x4, ..., 1x256]

    # Detokenize the subsequences in parallel. Batch size is 9.
    with get_bf16_context(enable_bf16):
        generator = get_generator(seed=seed, device=device)
        all_reconst = flextok_model.detokenize(
            subseq_list, timesteps=timesteps,
            guidance_scale=cfg_scale, perform_norm_guidance=perform_norm_guidance,
            generator=generator, verbose=False,
        )

    # Transform the reconstructions to PIL images, labeled by their token count
    all_images = [
        (
            TF.to_pil_image(denormalize(reconst_k).clamp(0, 1)),
            '1 token (2 bytes)' if k_keep == 1 else f'{k_keep} tokens ({2*k_keep} bytes)'
        )
        for reconst_k, k_keep in zip(all_reconst, K_KEEP_LIST)
    ]

    # Optionally super-resolve with AuraSR (skipped if the model failed to load)
    if super_res and aura_sr is not None:
        all_images = [(aura_sr.upscale_4x(img), label) for img, label in all_images]

    return all_images
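# Minimal sketch of calling `infer` directly, outside the Gradio UI (assumes an example image and a loaded model):
#   gallery = infer('examples/0.png', seed=1000, timesteps=25, cfg_scale=7.5)
#   # `gallery` is a list of (PIL.Image, label) pairs, one per token budget in K_KEEP_LIST;
#   # e.g. gallery[-1][0] is the reconstruction from all 256 tokens.
#   gallery[-1][0].save('reconstruction_256_tokens.png')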
examples = [
    'examples/0.png', 'examples/1.png', 'examples/2.png',
    'examples/3.png', 'examples/4.png', 'examples/5.png',
]
css="""
#col-container {
margin: 0 auto;
max-width: 1500px;
}
#col-input-container {
margin: 0 auto;
max-width: 400px;
}
#run-button {
margin: 0 auto;
}
#gallery {
aspect-ratio: 1/1 !important;
height: auto !important;
}
"""
with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
        # FlexTok: Resampling Images into 1D Token Sequences of Flexible Length
        """)

        with gr.Row():
            with gr.Column(elem_id="col-input-container"):
                gr.Markdown(f"""
                [`Website`](https://flextok.epfl.ch) | [`arXiv`](https://arxiv.org/abs/2502.13967) | [`GitHub`](https://github.com/apple/ml-flextok)

                Research demo for: <br>
                [**FlexTok: Resampling Images into 1D Token Sequences of Flexible Length**](https://arxiv.org/abs/2502.13967), arXiv 2025 <br>

                This demo uses the FlexTok tokenizer to autoencode the given RGB input, using [{MODEL_ID}](https://huggingface.co/{MODEL_ID}), running on *{power_device}*.
                The FlexTok encoder produces a 1D sequence of discrete tokens that are ordered in a coarse-to-fine manner.
                We show reconstructions from truncated subsequences, using the first 1, 2, 4, 8, ..., 256 tokens.
                As you will see, the first tokens capture high-level semantic content, while subsequent ones add fine-grained detail.
                """)

                img_path = gr.Image(label='RGB input image', type='filepath')
                run_button = gr.Button(f"Autoencode with {MODEL_NAME}", scale=0, elem_id="run-button")

                with gr.Accordion("Advanced Settings", open=False):
                    gr.Markdown(f"""
                    The FlexTok decoder is a rectified flow model. The following settings control the seed of the initial noise, the number of denoising timesteps,
                    the guidance scale, and whether to perform [Adaptive Projected Guidance](https://arxiv.org/abs/2410.02416) (we recommend enabling it).

                    This FlexTok model operates at 256x256 resolution. You can optionally super-resolve the reconstructions to 1024x1024 using
                    [Aura-SR](https://huggingface.co/fal/AuraSR) for sharper details, without changing the underlying reconstructed image too much.
                    """)
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=1000)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
                    timesteps = gr.Slider(label="Denoising timesteps", minimum=1, maximum=1000, step=1, value=25)
                    cfg_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=15.0, step=0.1, value=7.5)
                    perform_norm_guidance = gr.Checkbox(label="Perform Adaptive Projected Guidance", value=True)
                    super_res = gr.Checkbox(label="Super-resolve reconstructions from 256x256 to 1024x1024 with Aura-SR", value=False)

            result = gr.Gallery(
                label="Reconstructions", show_label=True, elem_id="gallery", type='pil',
                columns=[3], rows=None, object_fit="contain", height=800,
            )

        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[img_path],
            outputs=[result],
            cache_examples='lazy',
        )

    run_button.click(
        fn=infer,
        inputs=[img_path, seed, randomize_seed, timesteps, cfg_scale, perform_norm_guidance, super_res],
        outputs=[result],
    )
demo.queue(max_size=10).launch(share=True)