roman-bachmann committed
Commit a4f1fc6 · Parent: 6862991

Initial commit

Files changed (11)
  1. .gitattributes +1 -0
  2. .gitignore +4 -0
  3. README.md +3 -1
  4. app.py +150 -123
  5. examples/0.png +3 -0
  6. examples/1.png +3 -0
  7. examples/2.png +3 -0
  8. examples/3.png +3 -0
  9. examples/4.png +3 -0
  10. examples/5.png +3 -0
  11. requirements.txt +2 -6
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
+.gradio/
+.ipynb_checkpoints/
+.DS_Store
+__pycache__/
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: Text-to-Image Gradio Template
+title: FlexTok
 emoji: 🖼
 colorFrom: purple
 colorTo: red
@@ -7,6 +7,8 @@ sdk: gradio
 sdk_version: 5.0.1
 app_file: app.py
 pinned: false
+license: apache-2.0
+short_description: FlexTok flexible sequence length autoencoding demo
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,154 +1,181 @@
+from typing import List
+import os
+import spaces
 import gradio as gr
-import numpy as np
 import random
+from PIL import Image
+import matplotlib.pyplot as plt
 
-# import spaces #[uncomment to use ZeroGPU]
-from diffusers import DiffusionPipeline
+import einops
+import numpy as np
 import torch
+from torchvision import transforms
+import torchvision.transforms.functional as TF
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use
+from flextok.flextok_wrapper import FlexTokFromHub
+from flextok.utils.demo import imgs_from_urls, denormalize, batch_to_pil
+from flextok.utils.misc import detect_bf16_support, get_bf16_context, get_generator
 
+
+# We recommend running this demo on an A100 GPU
 if torch.cuda.is_available():
-    torch_dtype = torch.float16
+    device = "cuda"
+    gpu_type = torch.cuda.get_device_name(torch.cuda.current_device())
+    power_device = f"{gpu_type}"
+    torch.cuda.max_memory_allocated(device=device)
+    # Detect if bf16 is enabled or not
+    enable_bf16 = detect_bf16_support()
 else:
-    torch_dtype = torch.float32
+    # Define gpu_type on CPU as well, so the print below cannot raise a NameError
+    device, gpu_type, power_device, enable_bf16 = "cpu", "CPU", "CPU", False
+print(f'Device: {device}, GPU type: {gpu_type}')
+print('BF16 enabled:', enable_bf16)
+
+
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to False in PyTorch 1.12 and later.
+torch.backends.cuda.matmul.allow_tf32 = True
+# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
+torch.backends.cudnn.allow_tf32 = True
 
-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to(device)
+# Global no_grad
+torch.set_grad_enabled(False)
 
 MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
-
-
-# @spaces.GPU #[uncomment to use ZeroGPU]
-def infer(
-    prompt,
-    negative_prompt,
-    seed,
-    randomize_seed,
-    width,
-    height,
-    guidance_scale,
-    num_inference_steps,
-    progress=gr.Progress(track_tqdm=True),
-):
+
+MODEL_ID = 'EPFL-VILAB/flextok_d18_d28_dfn'
+MODEL_NAME = 'FlexTok d18-d28 (DFN)'
+
+# Load FlexTok model from HF Hub
+flextok_model = FlexTokFromHub.from_pretrained(MODEL_ID).to(device).eval()
+
+
+def img_from_path(
+    path: str,
+    img_size: int = 256,
+    mean: List[float] = [0.5, 0.5, 0.5],
+    std: List[float] = [0.5, 0.5, 0.5],
+) -> torch.Tensor:
+    # Image loading helper function
+    img_pil = Image.open(path).convert("RGB")
+
+    transform = transforms.Compose(
+        [
+            transforms.Resize(img_size),
+            transforms.CenterCrop(img_size),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ]
+    )
+
+    return transform(img_pil).unsqueeze(0)
+
+
+@spaces.GPU(duration=20)
+def infer(img_path, seed=0, randomize_seed=False, timesteps=20, cfg_scale=7.5, perform_norm_guidance=True):
     if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
+        seed = None
+
+    imgs = img_from_path(img_path).to(device)
+
+    # Tokenize images once
+    with get_bf16_context(enable_bf16):
+        tokens_list = flextok_model.tokenize(imgs)
 
-    generator = torch.Generator().manual_seed(seed)
+    # Create all token subsequences
+    k_keep_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
+    tokens_list = tokens_list * len(k_keep_list)
+    subseq_list = [seq[:, :k_keep].clone() for seq, k_keep in zip(tokens_list, k_keep_list)]
 
-    image = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator,
-    ).images[0]
+    # Detokenize all subsequences in parallel. Batch size is 9.
+    with get_bf16_context(enable_bf16):
+        generator = get_generator(seed=seed, device=device)
+        all_reconst = flextok_model.detokenize(
+            subseq_list, timesteps=timesteps,
+            guidance_scale=cfg_scale, perform_norm_guidance=perform_norm_guidance,
+            generator=generator, verbose=False,
+        )
 
-    return image, seed
+    # Transform to PIL images
+    all_images = [
+        (TF.to_pil_image(denormalize(reconst_k).clamp(0, 1)), f'{k_keep} tokens')
+        for reconst_k, k_keep in zip(all_reconst, k_keep_list)
+    ]
+
+    return all_images
 
 
 examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
+    'examples/0.png', 'examples/1.png', 'examples/2.png',
+    'examples/3.png', 'examples/4.png', 'examples/5.png',
 ]
 
-css = """
+css = """
 #col-container {
     margin: 0 auto;
-    max-width: 640px;
+    max-width: 1500px;
+}
+#col-input-container {
+    margin: 0 auto;
+    max-width: 400px;
+}
+#run-button {
+    margin: 0 auto;
+}
+#gallery {
+    aspect-ratio: 1/1 !important;
+    height: auto !important;
 }
 """
 
-with gr.Blocks(css=css) as demo:
+with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
+
     with gr.Column(elem_id="col-container"):
-        gr.Markdown(" # Text-to-Image Gradio Template")
-
+        gr.Markdown(f"""
+        # FlexTok: Resampling Images into 1D Token Sequences of Flexible Length
+        """)
+
         with gr.Row():
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                container=False,
-            )
-
-            run_button = gr.Button("Run", scale=0, variant="primary")
-
-        result = gr.Image(label="Result", show_label=False)
-
-        with gr.Accordion("Advanced Settings", open=False):
-            negative_prompt = gr.Text(
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-                visible=False,
-            )
-
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
-            )
-
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-            with gr.Row():
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
-                )
-
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
-                )
-
-            with gr.Row():
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
-                    maximum=10.0,
-                    step=0.1,
-                    value=0.0,  # Replace with defaults that work for your model
-                )
-
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
-                    step=1,
-                    value=2,  # Replace with defaults that work for your model
-                )
-
-        gr.Examples(examples=examples, inputs=[prompt])
-    gr.on(
-        triggers=[run_button.click, prompt.submit],
-        fn=infer,
-        inputs=[
-            prompt,
-            negative_prompt,
-            seed,
-            randomize_seed,
-            width,
-            height,
-            guidance_scale,
-            num_inference_steps,
-        ],
-        outputs=[result, seed],
-    )
-
-if __name__ == "__main__":
-    demo.launch()
+            with gr.Column(elem_id="col-input-container"):
+                gr.Markdown(f"""
+                [`Website`](https://flextok.epfl.ch) | [`arXiv`](https://arxiv.org/abs/2502.13967) | [`GitHub`](https://github.com/apple/ml-flextok)
+
+                Official demo for: <br>
+                [**FlexTok: Resampling Images into 1D Token Sequences of Flexible Length**](https://arxiv.org/abs/2502.13967), arXiv 2025 <br>
+                *[Roman Bachmann](https://roman-bachmann.github.io/)\*, [Jesse Allardice](https://github.com/JesseAllardice)\*, [David Mizrahi](https://dmizrahi.com/)\*, [Enrico Fini](https://scholar.google.com/citations?user=OQMtSKIAAAAJ), [Oğuzhan Fatih Kar](https://ofkar.github.io/), [Elmira Amirloo](https://elamirloo.github.io/), [Alaaeldin El-Nouby](https://aelnouby.github.io/), [Amir Zamir](https://vilab.epfl.ch/zamir/), [Afshin Dehghan](https://scholar.google.com/citations?user=wcX-UW4AAAAJ)*
+
+                This demo uses the FlexTok tokenizer to autoencode the given RGB input, using [{MODEL_ID}](https://huggingface.co/{MODEL_ID}), running on *{power_device}*. The FlexTok encoder produces a 1D sequence of discrete tokens that are ordered in a coarse-to-fine manner. We show reconstructions from truncated subsequences, using the first 1, 2, 4, 8, ..., 256 tokens. As you will see, the first tokens capture the high-level semantic content, while subsequent ones add more fine-grained detail.
+                """)
+
+                img_path = gr.Image(label='RGB input image', type='filepath')
+                run_button = gr.Button(f"Autoencode with {MODEL_NAME}", scale=0, elem_id="run-button")
+
+                with gr.Accordion("Advanced Settings", open=False):
+                    gr.Markdown(f"""
+                    The FlexTok decoder is a rectified flow model. The following settings control the seed of the initial noise, the number of denoising timesteps, the guidance scale, and whether to perform [Adaptive Projected Guidance](https://arxiv.org/abs/2410.02416) (we recommend enabling it).
+                    """)
+                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
+                    timesteps = gr.Slider(label="Denoising timesteps", minimum=1, maximum=1000, step=1, value=20)
+                    cfg_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=15.0, step=0.1, value=7.5)
+                    perform_norm_guidance = gr.Checkbox(label="Perform Adaptive Projected Guidance", value=True)
+
+            result = gr.Gallery(
+                label="Reconstructions", show_label=True, elem_id="gallery", type='pil',
+                columns=[3], rows=None, object_fit="contain", height=800
+            )
+
+        gr.Examples(
+            examples=examples,
+            fn=infer,
+            inputs=[img_path],
+            outputs=[result],
+            cache_examples='lazy',
+        )
+
+    run_button.click(
+        fn=infer,
+        inputs=[img_path, seed, randomize_seed, timesteps, cfg_scale, perform_norm_guidance],
+        outputs=[result]
+    )
+
+demo.queue(max_size=10).launch(share=True)
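The `infer` function above boils down to one `tokenize` call followed by a batched `detokenize` over truncated token sequences. The snippet below is a minimal sketch of that same round trip outside Gradio, reusing only calls that appear in `app.py`; the token tensor shape noted in the comment is an assumption inferred from how the demo slices `seq[:, :k_keep]`, not a documented guarantee.

```python
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision import transforms

from flextok.flextok_wrapper import FlexTokFromHub
from flextok.utils.demo import denormalize
from flextok.utils.misc import detect_bf16_support, get_bf16_context, get_generator

device = "cuda" if torch.cuda.is_available() else "cpu"
enable_bf16 = detect_bf16_support() if torch.cuda.is_available() else False
model = FlexTokFromHub.from_pretrained("EPFL-VILAB/flextok_d18_d28_dfn").to(device).eval()

# Same preprocessing as img_from_path in app.py: 256px center crop, values in [-1, 1]
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])
imgs = preprocess(Image.open("examples/0.png").convert("RGB")).unsqueeze(0).to(device)

with torch.no_grad(), get_bf16_context(enable_bf16):
    tokens_list = model.tokenize(imgs)         # assumed: list with one [1, 256] token tensor
    subseq = [tokens_list[0][:, :16].clone()]  # keep only the 16 coarsest tokens
    reconst = model.detokenize(
        subseq, timesteps=20, guidance_scale=7.5, perform_norm_guidance=True,
        generator=get_generator(seed=0, device=device), verbose=False,
    )

# Map from [-1, 1] back to an image, as infer does for each gallery entry
TF.to_pil_image(denormalize(reconst[0]).clamp(0, 1)).save("reconst_16_tokens.png")
```

Raising the cut-off from 16 toward 256 tokens trades compactness for reconstruction fidelity, which is exactly what the demo's gallery visualizes.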
 
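The `perform_norm_guidance` flag enables Adaptive Projected Guidance (APG, arXiv:2410.02416) inside flextok's `detokenize`; the actual implementation lives in the flextok package. As a rough illustration only, the sketch below shows the core APG idea applied to a pair of conditional/unconditional model outputs: cap the norm of the guidance difference, then keep mainly its component orthogonal to the conditional prediction. The function name, defaults, and the omission of the paper's momentum term are simplifying assumptions, not flextok's API.

```python
import torch

def apg_update(cond, uncond, guidance_scale=7.5, eta=0.0, norm_threshold=2.5):
    # Plain CFG would return cond + (guidance_scale - 1) * (cond - uncond),
    # which tends to oversaturate outputs at high guidance scales.
    diff = cond - uncond
    # (1) Cap the per-sample norm of the guidance difference.
    flat = diff.flatten(1)
    norm = flat.norm(dim=1, keepdim=True).clamp_min(1e-8)
    diff = (flat * (norm_threshold / norm).clamp(max=1.0)).view_as(diff)
    # (2) Split the difference into parts parallel/orthogonal to cond.
    cond_flat = cond.flatten(1)
    cond_unit = cond_flat / cond_flat.norm(dim=1, keepdim=True).clamp_min(1e-8)
    coeff = (diff.flatten(1) * cond_unit).sum(dim=1, keepdim=True)
    parallel = (coeff * cond_unit).view_as(diff)
    orthogonal = diff - parallel
    # (3) Down-weight the parallel part (eta = 0 drops it entirely).
    return cond + (guidance_scale - 1.0) * (orthogonal + eta * parallel)
```

In a rectified-flow sampler such as FlexTok's decoder, an update of this shape would replace the usual CFG combination of the two velocity predictions at each denoising step.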
examples/0.png ADDED

Git LFS Details

  • SHA256: 74353b65f4ea0594331634dbf110bf5acdf4b8d78ea5061803bb96e0b8f8b502
  • Pointer size: 132 Bytes
  • Size of remote file: 2.06 MB
examples/1.png ADDED

Git LFS Details

  • SHA256: b2a91403a665742918be170c789580e48bf2b11bf7429979cc459ae620f9daee
  • Pointer size: 132 Bytes
  • Size of remote file: 1.49 MB
examples/2.png ADDED

Git LFS Details

  • SHA256: 4031cb4b2b1c43cc1d1d8ea8506a629666b706885ac9b83cb5140f3de286d489
  • Pointer size: 131 Bytes
  • Size of remote file: 321 kB
examples/3.png ADDED

Git LFS Details

  • SHA256: 52bfc2039c5aa175eb2f6a2ab39f259506afa139e8367aca86433d06b8d579ff
  • Pointer size: 132 Bytes
  • Size of remote file: 1.93 MB
examples/4.png ADDED

Git LFS Details

  • SHA256: a5a55b11441fa0622c127f759a60f8cf1876d32fd704b3aeb47af1dc7a5c2b09
  • Pointer size: 132 Bytes
  • Size of remote file: 1.43 MB
examples/5.png ADDED

Git LFS Details

  • SHA256: a2c775a6df577c2af4d3ef98782be956423947461719191a630f1888598353c9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.35 MB
requirements.txt CHANGED
@@ -1,6 +1,2 @@
-accelerate
-diffusers
-invisible_watermark
-torch
-transformers
-xformers
+flextok @ git+https://github.com/apple/ml-flextok #@e115399
+spaces