xinjie.wang commited on
Commit
8e3d188
·
1 Parent(s): 1a190b2
Files changed (2) hide show
  1. app.py +19 -8
  2. common.py +9 -7
app.py CHANGED
@@ -109,6 +109,11 @@ with gr.Blocks(
109
  value=2048,
110
  step=256,
111
  )
 
 
 
 
 
112
  with gr.Row():
113
  randomize_seed = gr.Checkbox(
114
  label="Randomize Seed", value=False
@@ -207,10 +212,13 @@ with gr.Blocks(
207
  with gr.Row() as single_image_example:
208
  examples = gr.Examples(
209
  label="Image Gallery",
210
- examples=sorted(
211
- glob("assets/example_image/*")
212
- ),
213
- inputs=[image_prompt],
 
 
 
214
  fn=preprocess_image_fn,
215
  outputs=[image_prompt, raw_image_cache],
216
  run_on_click=True,
@@ -220,9 +228,12 @@ with gr.Blocks(
220
  with gr.Row(visible=False) as single_sam_image_example:
221
  examples = gr.Examples(
222
  label="Image Gallery",
223
- examples=sorted(
224
- glob("assets/example_image/*")
225
- ),
 
 
 
226
  inputs=[image_prompt_sam],
227
  fn=preprocess_sam_image_fn,
228
  outputs=[image_prompt_sam, raw_image_cache],
@@ -274,7 +285,7 @@ with gr.Blocks(
274
 
275
  image_prompt.upload(
276
  preprocess_image_fn,
277
- inputs=[image_prompt],
278
  outputs=[image_prompt, raw_image_cache],
279
  )
280
  image_prompt.change(
 
109
  value=2048,
110
  step=256,
111
  )
112
+ rmbg_tag = gr.Radio(
113
+ choices=["rembg", "rmbg14"],
114
+ value="rembg",
115
+ label="Background Removal Model",
116
+ )
117
  with gr.Row():
118
  randomize_seed = gr.Checkbox(
119
  label="Randomize Seed", value=False
 
212
  with gr.Row() as single_image_example:
213
  examples = gr.Examples(
214
  label="Image Gallery",
215
+ examples=[
216
+ [image_path]
217
+ for image_path in sorted(
218
+ glob("assets/example_image/*")
219
+ )
220
+ ],
221
+ inputs=[image_prompt, rmbg_tag],
222
  fn=preprocess_image_fn,
223
  outputs=[image_prompt, raw_image_cache],
224
  run_on_click=True,
 
228
  with gr.Row(visible=False) as single_sam_image_example:
229
  examples = gr.Examples(
230
  label="Image Gallery",
231
+ examples=[
232
+ [image_path]
233
+ for image_path in sorted(
234
+ glob("assets/example_image/*")
235
+ )
236
+ ],
237
  inputs=[image_prompt_sam],
238
  fn=preprocess_sam_image_fn,
239
  outputs=[image_prompt_sam, raw_image_cache],
 
285
 
286
  image_prompt.upload(
287
  preprocess_image_fn,
288
+ inputs=[image_prompt, rmbg_tag],
289
  outputs=[image_prompt, raw_image_cache],
290
  )
291
  image_prompt.change(
common.py CHANGED
@@ -150,8 +150,8 @@ def download_kolors_weights() -> None:
150
 
151
 
152
  if os.getenv("GRADIO_APP") == "imageto3d":
153
- # RBG_REMOVER = RembgRemover()
154
- RBG_REMOVER = BMGG14Remover()
155
  SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
156
  PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
157
  "JeffreyXiang/TRELLIS-image-large"
@@ -165,8 +165,8 @@ if os.getenv("GRADIO_APP") == "imageto3d":
165
  os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
166
  )
167
  elif os.getenv("GRADIO_APP") == "textto3d":
168
- # RBG_REMOVER = RembgRemover()
169
- RBG_REMOVER = BMGG14Remover()
170
  PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
171
  "JeffreyXiang/TRELLIS-image-large"
172
  )
@@ -289,7 +289,7 @@ def render_video(
289
 
290
  @spaces.GPU
291
  def preprocess_image_fn(
292
- image: str | np.ndarray | Image.Image,
293
  ) -> tuple[Image.Image, Image.Image]:
294
  if isinstance(image, str):
295
  image = Image.open(image)
@@ -298,7 +298,8 @@ def preprocess_image_fn(
298
 
299
  image_cache = image.copy().resize((512, 512))
300
 
301
- image = RBG_REMOVER(image)
 
302
  image = trellis_preprocess(image)
303
 
304
  return image, image_cache
@@ -710,6 +711,7 @@ def text2image_fn(
710
  ip_image: Image.Image | str = None,
711
  ip_adapt_scale: float = 0.3,
712
  image_wh: int | tuple[int, int] = [1024, 1024],
 
713
  n_sample: int = 3,
714
  req: gr.Request = None,
715
  ):
@@ -736,7 +738,7 @@ def text2image_fn(
736
 
737
  for idx in range(len(images)):
738
  image = images[idx]
739
- images[idx], _ = preprocess_image_fn(image)
740
 
741
  save_paths = []
742
  for idx, image in enumerate(images):
 
150
 
151
 
152
  if os.getenv("GRADIO_APP") == "imageto3d":
153
+ RBG_REMOVER = RembgRemover()
154
+ RBG14_REMOVER = BMGG14Remover()
155
  SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
156
  PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
157
  "JeffreyXiang/TRELLIS-image-large"
 
165
  os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
166
  )
167
  elif os.getenv("GRADIO_APP") == "textto3d":
168
+ RBG_REMOVER = RembgRemover()
169
+ RBG14_REMOVER = BMGG14Remover()
170
  PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
171
  "JeffreyXiang/TRELLIS-image-large"
172
  )
 
289
 
290
  @spaces.GPU
291
  def preprocess_image_fn(
292
+ image: str | np.ndarray | Image.Image, rmbg_tag: str = "rembg"
293
  ) -> tuple[Image.Image, Image.Image]:
294
  if isinstance(image, str):
295
  image = Image.open(image)
 
298
 
299
  image_cache = image.copy().resize((512, 512))
300
 
301
+ bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
302
+ image = bg_remover(image)
303
  image = trellis_preprocess(image)
304
 
305
  return image, image_cache
 
711
  ip_image: Image.Image | str = None,
712
  ip_adapt_scale: float = 0.3,
713
  image_wh: int | tuple[int, int] = [1024, 1024],
714
+ rmbg_tag: str = "rembg",
715
  n_sample: int = 3,
716
  req: gr.Request = None,
717
  ):
 
738
 
739
  for idx in range(len(images)):
740
  image = images[idx]
741
+ images[idx], _ = preprocess_image_fn(image, rmbg_tag)
742
 
743
  save_paths = []
744
  for idx, image in enumerate(images):