MohamedRashad commited on
Commit
a5480b1
·
1 Parent(s): cce8702

Update model pipeline and adjust inference steps in generate_item_image function

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -24,10 +24,10 @@ from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_imag
24
  llm_client = Client("Qwen/Qwen2.5-72B-Instruct")
25
 
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
- pipe = FluxPipeline.from_pretrained("sayakpaul/FLUX.1-merged", torch_dtype=torch.bfloat16).to(device)
28
  pipe.vae.enable_tiling()
29
  pipe.vae.enable_slicing()
30
- pipe.enable_sequential_cpu_offload() # offloads modules to CPU on a submodule level (rather than model level)
31
  torch.cuda.empty_cache()
32
 
33
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
@@ -82,13 +82,13 @@ def preprocess_pil_image(image: Image.Image) -> Tuple[str, Image.Image]:
82
  processed_image.save(f"{TMP_DIR}/{trial_id}.png")
83
  return trial_id, processed_image
84
 
85
- @spaces.GPU()
86
  def generate_item_image(object_t2i_prompt):
87
  trial_id = ""
88
  for image in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
89
  prompt=object_t2i_prompt,
90
  guidance_scale=3.5,
91
- num_inference_steps=1,
92
  width=512,
93
  height=512,
94
  generator=torch.Generator("cpu").manual_seed(0),
 
24
  llm_client = Client("Qwen/Qwen2.5-72B-Instruct")
25
 
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
+ pipe = FluxPipeline.from_pretrained("Freepik/flux.1-lite-8B-alpha", torch_dtype=torch.bfloat16).to(device)
28
  pipe.vae.enable_tiling()
29
  pipe.vae.enable_slicing()
30
+ # pipe.enable_sequential_cpu_offload() # offloads modules to CPU on a submodule level (rather than model level)
31
  torch.cuda.empty_cache()
32
 
33
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
 
82
  processed_image.save(f"{TMP_DIR}/{trial_id}.png")
83
  return trial_id, processed_image
84
 
85
+ @spaces.GPU(duration=75)
86
  def generate_item_image(object_t2i_prompt):
87
  trial_id = ""
88
  for image in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
89
  prompt=object_t2i_prompt,
90
  guidance_scale=3.5,
91
+ num_inference_steps=28,
92
  width=512,
93
  height=512,
94
  generator=torch.Generator("cpu").manual_seed(0),