derekl35 (HF Staff) committed on
Commit b66d963 · verified · 1 Parent(s): d5446b3

Update app.py

Files changed (1):
  1. app.py +312 -141
app.py CHANGED
@@ -1,154 +1,325 @@
 
  import gradio as gr
- import numpy as np
  import random

- # import spaces #[uncomment to use ZeroGPU]
- from diffusers import DiffusionPipeline
- import torch
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use
-
- if torch.cuda.is_available():
-     torch_dtype = torch.float16
- else:
-     torch_dtype = torch.float32
-
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
- pipe = pipe.to(device)
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024
-
-
- # @spaces.GPU #[uncomment to use ZeroGPU]
- def infer(
-     prompt,
-     negative_prompt,
-     seed,
-     randomize_seed,
-     width,
-     height,
-     guidance_scale,
-     num_inference_steps,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator().manual_seed(seed)
-
-     image = pipe(
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
-         generator=generator,
-     ).images[0]
-
-     return image, seed
-
-
- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]
-
- css = """
- #col-container {
-     margin: 0 auto;
-     max-width: 640px;
- }
- """
-
- with gr.Blocks(css=css) as demo:
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image Gradio Template")
-
-         with gr.Row():
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )

-             run_button = gr.Button("Run", scale=0, variant="primary")

-         result = gr.Image(label="Result", show_label=False)

-         with gr.Accordion("Advanced Settings", open=False):
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )

-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
              )

-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-             with gr.Row():
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024,  # Replace with defaults that work for your model
-                 )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024,  # Replace with defaults that work for your model
-                 )
-
-             with gr.Row():
-                 guidance_scale = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=10.0,
-                     step=0.1,
-                     value=0.0,  # Replace with defaults that work for your model
-                 )
-
-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=50,
-                     step=1,
-                     value=2,  # Replace with defaults that work for your model
-                 )
-
-         gr.Examples(examples=examples, inputs=[prompt])
-     gr.on(
-         triggers=[run_button.click, prompt.submit],
-         fn=infer,
-         inputs=[
-             prompt,
-             negative_prompt,
-             seed,
-             randomize_seed,
-             width,
-             height,
-             guidance_scale,
-             num_inference_steps,
-         ],
-         outputs=[result, seed],
      )

  if __name__ == "__main__":
-     demo.launch()
 
 
 
+ import torch
  import gradio as gr
+ from diffusers import FluxPipeline, FluxTransformer2DModel
+ from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
+ from transformers import T5EncoderModel
+ from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
+ import gc
  import random
+ from PIL import Image
+ import os
+ import time
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {DEVICE}")
+
+ DEFAULT_HEIGHT = 1024
+ DEFAULT_WIDTH = 1024
+ DEFAULT_GUIDANCE_SCALE = 3.5
+ DEFAULT_NUM_INFERENCE_STEPS = 50
+ DEFAULT_MAX_SEQUENCE_LENGTH = 512
+ GENERATION_SEED = 0  # could use a random number generator to set this, for more variety
+
+ def clear_gpu_memory(*args):
+     allocated_before = torch.cuda.memory_allocated(0) / 1024**3 if DEVICE == "cuda" else 0
+     reserved_before = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
+     print(f"Before clearing: Allocated={allocated_before:.2f} GB, Reserved={reserved_before:.2f} GB")
+
+     deleted_types = []
+     for arg in args:
+         if arg is not None:
+             deleted_types.append(str(type(arg)))
+             del arg
+
+     if deleted_types:
+         print(f"Deleted objects of types: {', '.join(deleted_types)}")
+     else:
+         print("No objects passed to clear_gpu_memory.")
+
+     gc.collect()
+     if DEVICE == "cuda":
+         torch.cuda.empty_cache()
+
+     allocated_after = torch.cuda.memory_allocated(0) / 1024**3 if DEVICE == "cuda" else 0
+     reserved_after = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
+     print(f"After clearing: Allocated={allocated_after:.2f} GB, Reserved={reserved_after:.2f} GB")
+     print("-" * 20)
+
+ CACHED_PIPES = {}
+ def load_bf16_pipeline():
+     """Loads the original FLUX.1-dev pipeline in BF16 precision."""
+     print("Loading BF16 pipeline...")
+     MODEL_ID = "black-forest-labs/FLUX.1-dev"
+     if MODEL_ID in CACHED_PIPES:
+         return CACHED_PIPES[MODEL_ID]
+     start_time = time.time()
+     try:
+         pipe = FluxPipeline.from_pretrained(
+             MODEL_ID,
+             torch_dtype=torch.bfloat16
+         )
+         pipe.to(DEVICE)
+         # pipe.enable_model_cpu_offload()
+         end_time = time.time()
+         mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
+         print(f"BF16 Pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
+         # CACHED_PIPES[MODEL_ID] = pipe
+         return pipe
+     except Exception as e:
+         print(f"Error loading BF16 pipeline: {e}")
+         raise  # Re-raise exception to be caught in generate_images
+
+ def load_bnb_8bit_pipeline():
+     """Loads the FLUX.1-dev pipeline with 8-bit quantized components."""
+     print("Loading 8-bit BNB pipeline...")
+     MODEL_ID = "derekl35/FLUX.1-dev-bnb-8bit"
+     if MODEL_ID in CACHED_PIPES:
+         return CACHED_PIPES[MODEL_ID]
+     start_time = time.time()
+     try:
+         pipe = FluxPipeline.from_pretrained(
+             MODEL_ID,
+             torch_dtype=torch.bfloat16
+         )
+         pipe.to(DEVICE)
+         # pipe.enable_model_cpu_offload()
+         end_time = time.time()
+         mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
+         print(f"8-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
+         CACHED_PIPES[MODEL_ID] = pipe
+         return pipe
+     except Exception as e:
+         print(f"Error loading 8-bit BNB pipeline: {e}")
+         raise
+
+ def load_bnb_4bit_pipeline():
+     """Loads the FLUX.1-dev pipeline with 4-bit quantized components."""
+     print("Loading 4-bit BNB pipeline...")
+     MODEL_ID = "derekl35/FLUX.1-dev-nf4"
+     if MODEL_ID in CACHED_PIPES:
+         return CACHED_PIPES[MODEL_ID]
+     start_time = time.time()
+     try:
+         pipe = FluxPipeline.from_pretrained(
+             MODEL_ID,
+             torch_dtype=torch.bfloat16
+         )
+         pipe.to(DEVICE)
+         # pipe.enable_model_cpu_offload()
+         end_time = time.time()
+         mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
+         print(f"4-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
+         CACHED_PIPES[MODEL_ID] = pipe
+         return pipe
+     except Exception as e:
+         print(f"Error loading 4-bit BNB pipeline: {e}")
+         raise
+
+ # --- Image Generation and Shuffling Function ---
+ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm=True)):
+     """Loads original and selected quantized model, generates one image each, clears memory, shuffles results."""
+     if not prompt:
+         return None, {}, gr.update(value="Please enter a prompt.", interactive=False), gr.update(choices=[], value=None)
+
+     if not quantization_choice:
+         # Return updates for all outputs to clear them or show warning
+         return None, {}, gr.update(value="Please select a quantization method.", interactive=False), gr.update(choices=[], value=None)
+
+     # Determine which quantized model to load
+     if quantization_choice == "8-bit":
+         quantized_load_func = load_bnb_8bit_pipeline
+         quantized_label = "Quantized (8-bit)"
+     elif quantization_choice == "4-bit":
+         quantized_load_func = load_bnb_4bit_pipeline
+         quantized_label = "Quantized (4-bit)"
+     else:
+         # Should not happen with Radio choices, but good practice
+         return None, {}, gr.update(value="Invalid quantization choice.", interactive=False), gr.update(choices=[], value=None)
+
+     model_configs = [
+         ("Original", load_bf16_pipeline),
+         (quantized_label, quantized_load_func),  # Use the specific label here
+     ]
+
+     results = []
+     pipe_kwargs = {
+         "prompt": prompt,
+         "height": DEFAULT_HEIGHT,
+         "width": DEFAULT_WIDTH,
+         "guidance_scale": DEFAULT_GUIDANCE_SCALE,
+         "num_inference_steps": DEFAULT_NUM_INFERENCE_STEPS,
+         "max_sequence_length": DEFAULT_MAX_SEQUENCE_LENGTH,
+     }
+
+     current_pipe = None  # Keep track of the current pipe for cleanup
+
+     for i, (label, load_func) in enumerate(model_configs):
+         progress(i / len(model_configs), desc=f"Loading {label} model...")
+         print(f"\n--- Loading {label} Model ---")
+         load_start_time = time.time()
+         try:
+             # Ensure previous pipe is cleared *before* loading the next
+             # if current_pipe:
+             #     print(f"--- Clearing memory before loading {label} Model ---")
+             #     clear_gpu_memory(current_pipe)
+             #     current_pipe = None
+
+             current_pipe = load_func()
+             load_end_time = time.time()
+             print(f"{label} model loaded in {load_end_time - load_start_time:.2f} seconds.")
+
+             progress((i + 0.5) / len(model_configs), desc=f"Generating with {label} model...")
+             print(f"--- Generating with {label} Model ---")
+             gen_start_time = time.time()
+             image_list = current_pipe(**pipe_kwargs, generator=torch.manual_seed(GENERATION_SEED)).images
+             image = image_list[0]
+             gen_end_time = time.time()
+             results.append({"label": label, "image": image})
+             print(f"--- Finished Generation with {label} Model in {gen_end_time - gen_start_time:.2f} seconds ---")
+             mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
+             print(f"Memory reserved: {mem_reserved:.2f} GB")
+
+         except Exception as e:
+             print(f"Error during {label} model processing: {e}")
+             # Attempt cleanup
+             if current_pipe:
+                 print(f"--- Clearing memory after error with {label} Model ---")
+                 clear_gpu_memory(current_pipe)
+                 current_pipe = None
+             # Return error state to Gradio - update all outputs
+             return None, {}, gr.update(value=f"Error processing {label} model: {e}", interactive=False), gr.update(choices=[], value=None)
+
+         # No finally block needed here, cleanup happens before next load or after loop
+
+     # Final cleanup after the loop finishes successfully
+     # if current_pipe:
+     #     print(f"--- Clearing memory after last model ({label}) ---")
+     #     clear_gpu_memory(current_pipe)
+     #     current_pipe = None
+
+     if len(results) != len(model_configs):
+         print("Generation did not complete for all models.")
+         # Update all outputs
+         return None, {}, gr.update(value="Failed to generate images for all model types.", interactive=False), gr.update(choices=[], value=None)
+
+     # Shuffle the results for display
+     shuffled_results = results.copy()
+     random.shuffle(shuffled_results)
+
+     # Create the gallery data: [(image, caption), (image, caption)]
+     shuffled_data_for_gallery = [(res["image"], f"Image {i+1}") for i, res in enumerate(shuffled_results)]
+
+     # Create the mapping: display_index -> correct_label (e.g., {0: 'Original', 1: 'Quantized (8-bit)'})
+     correct_mapping = {i: res["label"] for i, res in enumerate(shuffled_results)}
+     print("Correct mapping (hidden):", correct_mapping)
+
+     guess_radio_update = gr.update(choices=["Image 1", "Image 2"], value=None, interactive=True)
+
+     # Return shuffled images, the correct mapping state, status message, and update the guess radio
+     return shuffled_data_for_gallery, correct_mapping, gr.update(value="Generation complete! Make your guess.", interactive=False), guess_radio_update
+
+
+ # --- Guess Verification Function ---
+ def check_guess(user_guess, correct_mapping_state):
+     """Compares the user's guess with the correct mapping stored in the state."""
+
+     if not isinstance(correct_mapping_state, dict) or not correct_mapping_state:
+         return "Please generate images first (state is empty or invalid)."
+
+     if user_guess is None:
+         return "Please select which image you think is quantized."
+
+     # Find which display index (0 or 1) corresponds to the quantized image
+     quantized_image_index = -1
+     quantized_label_actual = ""
+     for index, label in correct_mapping_state.items():
+         if "Quantized" in label:  # Check if the label indicates quantization
+             quantized_image_index = index
+             quantized_label_actual = label  # Store the full label e.g. "Quantized (8-bit)"
+             break
+
+     if quantized_image_index == -1:
+         # This shouldn't happen if generation was successful
+         return "Error: Could not find the quantized image in the mapping data."
+
+     # Determine what the user *should* have selected based on the index
+     correct_guess_label = f"Image {quantized_image_index + 1}"  # "Image 1" or "Image 2"
+
+     if user_guess == correct_guess_label:
+         feedback = f"Correct! {correct_guess_label} used the {quantized_label_actual} model."
+     else:
+         feedback = f"Incorrect. The quantized image ({quantized_label_actual}) was {correct_guess_label}."
+
+     return feedback
+
+
+ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# FLUX Model Quantization Challenge")
+     gr.Markdown(
+         "Compare the original FLUX.1-dev (BF16) model against a quantized version (4-bit or 8-bit). "
+         "Enter a prompt, choose the quantization method, and generate two images. "
+         "The images will be shuffled. Can you guess which one used quantization?"
+     )
+
+     with gr.Row():
+         prompt_input = gr.Textbox(label="Enter Prompt", placeholder="e.g., A photorealistic portrait of an astronaut on Mars", scale=3)
+         quantization_choice_radio = gr.Radio(
+             choices=["8-bit", "4-bit"],
+             label="Select Quantization",
+             value="8-bit",  # Default choice
+             scale=1
+         )
+         generate_button = gr.Button("Generate & Compare", variant="primary", scale=1)
+
+     output_gallery = gr.Gallery(
+         label="Generated Images (Original vs. Quantized)",
+         columns=2,
+         height=512,
+         object_fit="contain",
+         allow_preview=True,
+         show_label=True,  # Shows "Image 1", "Image 2" captions we provide
+     )
+
+     gr.Markdown("### Which image used the selected quantization method?")
+     with gr.Row():
+         # Centered guess radio and submit button
+         with gr.Column(scale=1):  # Dummy column for spacing
+             pass
+         with gr.Column(scale=2):  # Column for the radio button
+             guess_radio = gr.Radio(
+                 choices=[],
+                 label="Your Guess",
+                 info="Select the image you believe was generated with the quantized model.",
+                 interactive=False  # Disabled until images are generated
              )
+         with gr.Column(scale=1):  # Column for the button
+             submit_guess_button = gr.Button("Submit Guess")
+         with gr.Column(scale=1):  # Dummy column for spacing
+             pass
+
+     feedback_box = gr.Textbox(label="Feedback", interactive=False, lines=1)
+
+     # Hidden state to store the correct mapping after shuffling
+     # e.g., {0: 'Original', 1: 'Quantized (8-bit)'} or {0: 'Quantized (4-bit)', 1: 'Original'}
+     correct_mapping_state = gr.State({})
+
+     generate_button.click(
+         fn=generate_images,
+         inputs=[prompt_input, quantization_choice_radio],
+         outputs=[output_gallery, correct_mapping_state, feedback_box, guess_radio]
+     ).then(
+         lambda: "",  # Clear feedback box on new generation
+         outputs=[feedback_box]
+     )
+
+     submit_guess_button.click(
+         fn=check_guess,
+         inputs=[guess_radio, correct_mapping_state],  # Pass the selected guess and the state
+         outputs=[feedback_box]
      )

  if __name__ == "__main__":
+     # queue()
+     # demo.queue().launch()  # Set share=True to create public link if needed
+     demo.launch()
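
Note: the new app.py imports FluxTransformer2DModel, T5EncoderModel, and the diffusers/transformers BitsAndBytesConfig classes but never calls them; the quantized pipelines are loaded from the pre-quantized repos derekl35/FLUX.1-dev-bnb-8bit and derekl35/FLUX.1-dev-nf4. The sketch below shows how such a checkpoint could plausibly be produced with those imports. It is an illustration only, not code from this commit; it assumes a diffusers/transformers install with bitsandbytes support, and the output directory name is hypothetical.

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import T5EncoderModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

MODEL_ID = "black-forest-labs/FLUX.1-dev"

# 4-bit NF4 quantization for the FLUX transformer (diffusers-side config).
transformer = FluxTransformer2DModel.from_pretrained(
    MODEL_ID,
    subfolder="transformer",
    quantization_config=DiffusersBitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    torch_dtype=torch.bfloat16,
)

# The same treatment for the T5 text encoder (transformers-side config).
text_encoder_2 = T5EncoderModel.from_pretrained(
    MODEL_ID,
    subfolder="text_encoder_2",
    quantization_config=TransformersBitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    torch_dtype=torch.bfloat16,
)

# Assemble the pipeline around the quantized components and save it, so it can
# later be reloaded with a plain FluxPipeline.from_pretrained(...) call, as app.py does.
pipe = FluxPipeline.from_pretrained(
    MODEL_ID,
    transformer=transformer,
    text_encoder_2=text_encoder_2,
    torch_dtype=torch.bfloat16,
)
pipe.save_pretrained("FLUX.1-dev-nf4")  # hypothetical local output path

Loading the pre-quantized repos at runtime, as the Space does, avoids re-quantizing the weights on every restart; an 8-bit variant would use load_in_8bit=True in both configs instead of the 4-bit options.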