DerekLiu35 committed
Commit f0fcc66 · 1 Parent(s): 3bb963b
Files changed (2)
  1. app.py +328 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,328 @@
+ import torch
+ import gradio as gr
+ from diffusers import FluxPipeline
+ import gc
+ import random
+ import os
+ import time
+ import spaces
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {DEVICE}")
+
+ DEFAULT_HEIGHT = 1024
+ DEFAULT_WIDTH = 1024
+ DEFAULT_GUIDANCE_SCALE = 3.5
+ DEFAULT_NUM_INFERENCE_STEPS = 15
+ DEFAULT_MAX_SEQUENCE_LENGTH = 512
+ GENERATION_SEED = 0  # fixed seed so both pipelines start from identical latents; randomize for more variety
+ HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN")
+
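+ # The two quantized repos loaded below appear to be pre-exported checkpoints. A
+ # minimal sketch of quantizing on the fly instead (assuming diffusers' bitsandbytes
+ # integration is available) would look like:
+ #
+ #   from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
+ #   quant_config = BitsAndBytesConfig(
+ #       load_in_4bit=True,
+ #       bnb_4bit_quant_type="nf4",
+ #       bnb_4bit_compute_dtype=torch.bfloat16,
+ #   )
+ #   transformer = FluxTransformer2DModel.from_pretrained(
+ #       "black-forest-labs/FLUX.1-dev", subfolder="transformer",
+ #       quantization_config=quant_config, torch_dtype=torch.bfloat16, token=HF_TOKEN,
+ #   )
+ #   pipe = FluxPipeline.from_pretrained(
+ #       "black-forest-labs/FLUX.1-dev", transformer=transformer,
+ #       torch_dtype=torch.bfloat16, token=HF_TOKEN,
+ #   )
+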
+ def clear_gpu_memory(*args):
+     """Drop references to the given objects, run gc, and empty the CUDA cache.
+
+     Note: `del arg` below only removes this function's *local* reference; the
+     caller must also drop its own reference (e.g. `current_pipe = None`) before
+     the memory can actually be reclaimed.
+     """
+     allocated_before = torch.cuda.memory_allocated(0) / 1024**3 if DEVICE == "cuda" else 0
+     reserved_before = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
+     print(f"Before clearing: Allocated={allocated_before:.2f} GB, Reserved={reserved_before:.2f} GB")
+
+     deleted_types = []
+     for arg in args:
+         if arg is not None:
+             deleted_types.append(str(type(arg)))
+             del arg
+
+     if deleted_types:
+         print(f"Deleted objects of types: {', '.join(deleted_types)}")
+     else:
+         print("No objects passed to clear_gpu_memory.")
+
+     gc.collect()
+     if DEVICE == "cuda":
+         torch.cuda.empty_cache()
+
+     allocated_after = torch.cuda.memory_allocated(0) / 1024**3 if DEVICE == "cuda" else 0
+     reserved_after = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
+     print(f"After clearing: Allocated={allocated_after:.2f} GB, Reserved={reserved_after:.2f} GB")
+     print("-" * 20)
+
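+ # Usage sketch: pass the pipeline, then drop your own reference too, e.g.
+ #   clear_gpu_memory(pipe)
+ #   pipe = None
+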
+ CACHED_PIPES = {}
+ def load_bf16_pipeline():
+     """Loads the original FLUX.1-dev pipeline in BF16 precision."""
+     print("Loading BF16 pipeline...")
+     MODEL_ID = "black-forest-labs/FLUX.1-dev"
+     if MODEL_ID in CACHED_PIPES:
+         return CACHED_PIPES[MODEL_ID]
+     start_time = time.time()
+     try:
+         pipe = FluxPipeline.from_pretrained(
+             MODEL_ID,
+             torch_dtype=torch.bfloat16,
+             token=HF_TOKEN
+         )
+         pipe.to(DEVICE)
+         # pipe.enable_model_cpu_offload()
+         end_time = time.time()
+         mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
+         print(f"BF16 pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
+         # Unlike the quantized pipes, the BF16 pipe is left uncached here, so its
+         # (much larger) weights can be freed between runs:
+         # CACHED_PIPES[MODEL_ID] = pipe
+         return pipe
+     except Exception as e:
+         print(f"Error loading BF16 pipeline: {e}")
+         raise  # Re-raise so generate_images can surface the error
+
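+ # Note: FLUX.1-dev is a gated model on the Hub, so HF_ACCESS_TOKEN must belong to
+ # an account that has accepted its license, or the BF16 load above will fail.
+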
+ def load_bnb_8bit_pipeline():
+     """Loads the FLUX.1-dev pipeline with 8-bit quantized components."""
+     print("Loading 8-bit BNB pipeline...")
+     MODEL_ID = "derekl35/FLUX.1-dev-bnb-8bit"
+     if MODEL_ID in CACHED_PIPES:
+         return CACHED_PIPES[MODEL_ID]
+     start_time = time.time()
+     try:
+         pipe = FluxPipeline.from_pretrained(
+             MODEL_ID,
+             torch_dtype=torch.bfloat16
+         )
+         pipe.to(DEVICE)
+         # pipe.enable_model_cpu_offload()
+         end_time = time.time()
+         mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
+         print(f"8-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
+         CACHED_PIPES[MODEL_ID] = pipe
+         return pipe
+     except Exception as e:
+         print(f"Error loading 8-bit BNB pipeline: {e}")
+         raise
+
+ def load_bnb_4bit_pipeline():
+     """Loads the FLUX.1-dev pipeline with 4-bit quantized components."""
+     print("Loading 4-bit BNB pipeline...")
+     MODEL_ID = "derekl35/FLUX.1-dev-nf4"
+     if MODEL_ID in CACHED_PIPES:
+         return CACHED_PIPES[MODEL_ID]
+     start_time = time.time()
+     try:
+         pipe = FluxPipeline.from_pretrained(
+             MODEL_ID,
+             torch_dtype=torch.bfloat16
+         )
+         pipe.to(DEVICE)
+         # pipe.enable_model_cpu_offload()
+         end_time = time.time()
+         mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
+         print(f"4-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
+         CACHED_PIPES[MODEL_ID] = pipe
+         return pipe
+     except Exception as e:
+         print(f"Error loading 4-bit BNB pipeline: {e}")
+         raise
+
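+ # The NF4 repo stores transformer weights at roughly 4 bits per parameter versus
+ # 8 bits for the repo above, saving further memory at a possible cost in fidelity.
+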
+ @spaces.GPU(duration=240)
+ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm=True)):
+     """Loads the original and the selected quantized model, generates one image with each, and shuffles the results."""
+     if not prompt:
+         return None, {}, gr.update(value="Please enter a prompt.", interactive=False), gr.update(choices=[], value=None)
+
+     if not quantization_choice:
+         # Return updates for all outputs to clear them or show the warning
+         return None, {}, gr.update(value="Please select a quantization method.", interactive=False), gr.update(choices=[], value=None)
+
+     # Determine which quantized model to load
+     if quantization_choice == "8-bit":
+         quantized_load_func = load_bnb_8bit_pipeline
+         quantized_label = "Quantized (8-bit)"
+     elif quantization_choice == "4-bit":
+         quantized_load_func = load_bnb_4bit_pipeline
+         quantized_label = "Quantized (4-bit)"
+     else:
+         # Should not happen with Radio choices, but good practice
+         return None, {}, gr.update(value="Invalid quantization choice.", interactive=False), gr.update(choices=[], value=None)
+
+     model_configs = [
+         ("Original", load_bf16_pipeline),
+         (quantized_label, quantized_load_func),  # Use the specific label here
+     ]
+
+     results = []
+     pipe_kwargs = {
+         "prompt": prompt,
+         "height": DEFAULT_HEIGHT,
+         "width": DEFAULT_WIDTH,
+         "guidance_scale": DEFAULT_GUIDANCE_SCALE,
+         "num_inference_steps": DEFAULT_NUM_INFERENCE_STEPS,
+         "max_sequence_length": DEFAULT_MAX_SEQUENCE_LENGTH,
+     }
+
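+     # Both pipelines are seeded identically below, so visible differences between
+     # the two images should stem from quantization rather than from different noise.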
+     current_pipe = None  # Keep track of the current pipe for cleanup
+
+     for i, (label, load_func) in enumerate(model_configs):
+         progress(i / len(model_configs), desc=f"Loading {label} model...")
+         print(f"\n--- Loading {label} Model ---")
+         load_start_time = time.time()
+         try:
+             # NOTE: clearing the previous pipe *before* loading the next is
+             # currently disabled, so both pipelines may be resident at once:
+             # if current_pipe:
+             #     clear_gpu_memory(current_pipe)
+             #     current_pipe = None
+
+             current_pipe = load_func()
+             load_end_time = time.time()
+             print(f"{label} model loaded in {load_end_time - load_start_time:.2f} seconds.")
+
+             progress((i + 0.5) / len(model_configs), desc=f"Generating with {label} model...")
+             print(f"--- Generating with {label} Model ---")
+             gen_start_time = time.time()
+             image_list = current_pipe(**pipe_kwargs, generator=torch.manual_seed(GENERATION_SEED)).images
+             image = image_list[0]
+             gen_end_time = time.time()
+             results.append({"label": label, "image": image})
+             print(f"--- Finished Generation with {label} Model in {gen_end_time - gen_start_time:.2f} seconds ---")
+             mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
+             print(f"Memory reserved: {mem_reserved:.2f} GB")
+
+         except Exception as e:
+             print(f"Error during {label} model processing: {e}")
+             # Attempt cleanup
+             if current_pipe:
+                 print(f"--- Clearing memory after error with {label} Model ---")
+                 clear_gpu_memory(current_pipe)
+                 current_pipe = None
+             # Return error state to Gradio - update all outputs
+             return None, {}, gr.update(value=f"Error processing {label} model: {e}", interactive=False), gr.update(choices=[], value=None)
+
+     # Final cleanup after the loop is likewise disabled; on ZeroGPU the device is
+     # released when this @spaces.GPU call returns.
+
+     if len(results) != len(model_configs):
+         print("Generation did not complete for all models.")
+         # Update all outputs
+         return None, {}, gr.update(value="Failed to generate images for all model types.", interactive=False), gr.update(choices=[], value=None)
+
+     # Shuffle the results for display
+     shuffled_results = results.copy()
+     random.shuffle(shuffled_results)
+
+     # Create the gallery data: [(image, caption), (image, caption)]
+     shuffled_data_for_gallery = [(res["image"], f"Image {i+1}") for i, res in enumerate(shuffled_results)]
+
+     # Create the mapping: display_index -> correct_label, e.g. {0: 'Original', 1: 'Quantized (8-bit)'}
+     correct_mapping = {i: res["label"] for i, res in enumerate(shuffled_results)}
+     print("Correct mapping (hidden):", correct_mapping)
+
+     guess_radio_update = gr.update(choices=["Image 1", "Image 2"], value=None, interactive=True)
+
+     # Return shuffled images, the correct mapping state, a status message, and the updated guess radio
+     return shuffled_data_for_gallery, correct_mapping, gr.update(value="Generation complete! Make your guess.", interactive=False), guess_radio_update
+
+
+ # --- Guess Verification Function ---
+ def check_guess(user_guess, correct_mapping_state):
+     """Compares the user's guess with the correct mapping stored in the state."""
+
+     if not isinstance(correct_mapping_state, dict) or not correct_mapping_state:
+         return "Please generate images first (state is empty or invalid)."
+
+     if user_guess is None:
+         return "Please select which image you think is quantized."
+
+     # Find which display index (0 or 1) corresponds to the quantized image
+     quantized_image_index = -1
+     quantized_label_actual = ""
+     for index, label in correct_mapping_state.items():
+         if "Quantized" in label:  # Check if the label indicates quantization
+             quantized_image_index = index
+             quantized_label_actual = label  # Store the full label, e.g. "Quantized (8-bit)"
+             break
+
+     if quantized_image_index == -1:
+         # This shouldn't happen if generation was successful
+         return "Error: Could not find the quantized image in the mapping data."
+
+     # Determine what the user *should* have selected based on the index
+     correct_guess_label = f"Image {quantized_image_index + 1}"  # "Image 1" or "Image 2"
+
+     if user_guess == correct_guess_label:
+         feedback = f"Correct! {correct_guess_label} used the {quantized_label_actual} model."
+     else:
+         feedback = f"Incorrect. The quantized image ({quantized_label_actual}) was {correct_guess_label}."
+
+     return feedback
+
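+ # For example, with correct_mapping_state == {0: "Original", 1: "Quantized (8-bit)"}:
+ #   check_guess("Image 2", state) -> "Correct! Image 2 used the Quantized (8-bit) model."
+ #   check_guess("Image 1", state) -> "Incorrect. The quantized image (Quantized (8-bit)) was Image 2."
+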
+
+ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# FLUX Model Quantization Challenge")
+     gr.Markdown(
+         "Compare the original FLUX.1-dev (BF16) model against a quantized version (4-bit or 8-bit). "
+         "Enter a prompt, choose the quantization method, and generate two images. "
+         "The images will be shuffled. Can you guess which one used quantization?"
+     )
+
+     with gr.Row():
+         prompt_input = gr.Textbox(label="Enter Prompt", placeholder="e.g., A photorealistic portrait of an astronaut on Mars", scale=3)
+         quantization_choice_radio = gr.Radio(
+             choices=["8-bit", "4-bit"],
+             label="Select Quantization",
+             value="8-bit",  # Default choice
+             scale=1
+         )
+         generate_button = gr.Button("Generate & Compare", variant="primary", scale=1)
+
+     output_gallery = gr.Gallery(
+         label="Generated Images (Original vs. Quantized)",
+         columns=2,
+         height=512,
+         object_fit="contain",
+         allow_preview=True,
+         show_label=True,  # Shows the gallery label; the "Image 1"/"Image 2" captions come from the data tuples
+     )
+
+     gr.Markdown("### Which image used the selected quantization method?")
+     with gr.Row():
+         # Centered guess radio and submit button
+         with gr.Column(scale=1):  # Dummy column for spacing
+             pass
+         with gr.Column(scale=2):  # Column for the radio button
+             guess_radio = gr.Radio(
+                 choices=[],
+                 label="Your Guess",
+                 info="Select the image you believe was generated with the quantized model.",
+                 interactive=False  # Disabled until images are generated
+             )
+         with gr.Column(scale=1):  # Column for the button
+             submit_guess_button = gr.Button("Submit Guess")
+         with gr.Column(scale=1):  # Dummy column for spacing
+             pass
+
+     feedback_box = gr.Textbox(label="Feedback", interactive=False, lines=1)
+
+     # Hidden state storing the correct mapping after shuffling,
+     # e.g. {0: 'Original', 1: 'Quantized (8-bit)'} or {0: 'Quantized (4-bit)', 1: 'Original'}
+     correct_mapping_state = gr.State({})
+
+     generate_button.click(
+         fn=lambda: "",  # Clear stale feedback *before* a new generation, not after it
+         outputs=[feedback_box]
+     ).then(
+         fn=generate_images,
+         inputs=[prompt_input, quantization_choice_radio],
+         outputs=[output_gallery, correct_mapping_state, feedback_box, guess_radio]
+     )
+
+
+     submit_guess_button.click(
+         fn=check_guess,
+         inputs=[guess_radio, correct_mapping_state],  # Pass the selected guess and the state
+         outputs=[feedback_box]
+     )
+
+ if __name__ == "__main__":
+     # demo.queue() could be enabled to serialize requests; pass share=True to
+     # launch() for a public link.
+     demo.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ accelerate
+ diffusers
+ invisible_watermark
+ torch==2.4.0
+ transformers
+ xformers
+ bitsandbytes
+ sentencepiece