DerekLiu35 committed on
Commit b64eca0 · 1 Parent(s): 575f433

change guessing to use buttons

Files changed (1)
  1. app.py +20 -33
app.py CHANGED
@@ -1,9 +1,6 @@
 import torch
 import gradio as gr
 from diffusers import FluxPipeline, FluxTransformer2DModel
-from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
-from transformers import T5EncoderModel
-from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
 import gc
 import random
 from PIL import Image
@@ -131,6 +128,9 @@ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm
 
     current_pipe = None # Keep track of the current pipe for cleanup
 
+    seed = random.getrandbits(64)
+    print(f"Using seed: {seed}")
+
     for i, (label, load_func) in enumerate(model_configs):
         progress(i / len(model_configs), desc=f"Loading {label} model...")
         print(f"\n--- Loading {label} Model ---")
@@ -143,8 +143,9 @@ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm
         progress((i + 0.5) / len(model_configs), desc=f"Generating with {label} model...")
         print(f"--- Generating with {label} Model ---")
        gen_start_time = time.time()
-        image_list = current_pipe(**pipe_kwargs, generator=torch.manual_seed(GENERATION_SEED)).images
+        image_list = current_pipe(**pipe_kwargs, generator=torch.manual_seed(seed)).images
         image = image_list[0]
+        # image.save(f"{load_start_time}.png")
         gen_end_time = time.time()
         results.append({"label": label, "image": image})
         print(f"--- Finished Generation with {label} Model in {gen_end_time - gen_start_time:.2f} seconds ---")
@@ -219,11 +220,11 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
     gr.Markdown(
         "Compare the original FLUX.1-dev (BF16) model against a quantized version (4-bit or 8-bit). "
         "Enter a prompt, choose the quantization method, and generate two images. "
-        "The images will be shuffled. Can you guess which one used quantization?"
+        "The images will be shuffled, can you spot which one was quantized?"
     )
 
     with gr.Row():
-        prompt_input = gr.Textbox(label="Enter Prompt", placeholder="e.g., A photorealistic portrait of an astronaut on Mars", scale=3)
+        prompt_input = gr.Textbox(label="Enter Prompt", scale=3)
         quantization_choice_radio = gr.Radio(
             choices=["8-bit", "4-bit"],
             label="Select Quantization",
@@ -243,20 +244,8 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
 
     gr.Markdown("### Which image used the selected quantization method?")
     with gr.Row():
-        # Centered guess radio and submit button
-        with gr.Column(scale=1): # Dummy column for spacing
-            pass
-        with gr.Column(scale=2): # Column for the radio button
-            guess_radio = gr.Radio(
-                choices=[],
-                label="Your Guess",
-                info="Select the image you believe was generated with the quantized model.",
-                interactive=False # Disabled until images are generated
-            )
-        with gr.Column(scale=1): # Column for the button
-            submit_guess_button = gr.Button("Submit Guess")
-        with gr.Column(scale=1): # Dummy column for spacing
-            pass
+        image1_btn = gr.Button("Image 1")
+        image2_btn = gr.Button("Image 2")
 
     feedback_box = gr.Textbox(label="Feedback", interactive=False, lines=1)
 
@@ -267,20 +256,18 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
     generate_button.click(
         fn=generate_images,
         inputs=[prompt_input, quantization_choice_radio],
-        outputs=[output_gallery, correct_mapping_state, feedback_box, guess_radio]
-    ).then(
-        lambda: "", # Clear feedback box on new generation
-        outputs=[feedback_box]
-    )
+        outputs=[output_gallery, correct_mapping_state],
+    ).then(lambda: "", outputs=[feedback_box]) # clear feedback on new run
 
+    # helper wrappers so we can supply the fixed choice string
+    def choose_img1(mapping):
+        return check_guess("Image 1", mapping)
+    def choose_img2(mapping):
+        return check_guess("Image 2", mapping)
 
-    submit_guess_button.click(
-        fn=check_guess,
-        inputs=[guess_radio, correct_mapping_state], # Pass the selected guess and the state
-        outputs=[feedback_box]
-    )
+    image1_btn.click(choose_img1, inputs=[correct_mapping_state], outputs=[feedback_box])
+    image2_btn.click(choose_img2, inputs=[correct_mapping_state], outputs=[feedback_box])
 
 if __name__ == "__main__":
-    # queue()
-    # demo.queue().launch()
-    demo.launch(share=True)
+    demo.launch(share=True)
+    demo.launch()
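A minimal, self-contained sketch of the button-based guess flow introduced in this commit, runnable outside the full Space. The check_guess body and the {"Image 1": ..., "Image 2": ...} mapping format shown here are assumptions for illustration; in app.py the real check_guess and correct_mapping_state are defined elsewhere in the file and populated by generate_images.

import gradio as gr

def check_guess(choice, mapping):
    # Stand-in for app.py's check_guess: mapping records which displayed
    # image came from the quantized pipeline (assumed format for this sketch).
    if not mapping:
        return "Generate images first!"
    return "Correct!" if mapping.get(choice) == "quantized" else "Wrong - that one was the original."

with gr.Blocks() as demo:
    # In the real app this state is filled by generate_images; it is
    # hard-coded here so the buttons can be exercised on their own.
    correct_mapping_state = gr.State({"Image 1": "quantized", "Image 2": "original"})
    feedback_box = gr.Textbox(label="Feedback", interactive=False)

    with gr.Row():
        image1_btn = gr.Button("Image 1")
        image2_btn = gr.Button("Image 2")

    # Same wiring idea as the commit: each button bakes in its own label, so no
    # Radio selection step is needed (lambdas here instead of the named wrappers).
    image1_btn.click(lambda m: check_guess("Image 1", m),
                     inputs=[correct_mapping_state], outputs=[feedback_box])
    image2_btn.click(lambda m: check_guess("Image 2", m),
                     inputs=[correct_mapping_state], outputs=[feedback_box])

if __name__ == "__main__":
    demo.launch()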