xinjie.wang commited on
Commit
a75fc8a
·
1 Parent(s): eae4507
Files changed (2) hide show
  1. asset3d_gen/models/segment_model.py +30 -2
  2. common.py +5 -2
asset3d_gen/models/segment_model.py CHANGED
@@ -6,6 +6,7 @@ import cv2
6
  import numpy as np
7
  import rembg
8
  import torch
 
9
  from huggingface_hub import snapshot_download
10
  from PIL import Image
11
  from segment_anything import (
@@ -292,6 +293,30 @@ class RembgRemover(object):
292
  return output_image
293
 
294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  def invert_rgba_pil(
296
  image: Image.Image, mask: Image.Image, save_path: str = None
297
  ) -> Image.Image:
@@ -368,8 +393,11 @@ if __name__ == "__main__":
368
  model_type="vit_h",
369
  )
370
  remover = RembgRemover()
371
- # clean_image = remover(input_image)
372
- # clean_image.save(output_image)
373
  get_segmented_image(
374
  Image.open(input_image), remover, remover, None, "./test_seg.png"
375
  )
 
 
 
 
6
  import numpy as np
7
  import rembg
8
  import torch
9
+ from transformers import pipeline
10
  from huggingface_hub import snapshot_download
11
  from PIL import Image
12
  from segment_anything import (
 
293
  return output_image
294
 
295
 
296
+ class BMGG14Remover(object):
297
+ def __init__(self) -> None:
298
+ self.model = pipeline(
299
+ "image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True
300
+ )
301
+
302
+ def __call__(
303
+ self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
304
+ ):
305
+ if isinstance(image, str):
306
+ image = Image.open(image)
307
+ elif isinstance(image, np.ndarray):
308
+ image = Image.fromarray(image)
309
+
310
+ image = resize_pil(image)
311
+ output_image = self.model(image)
312
+
313
+ if save_path is not None:
314
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
315
+ output_image.save(save_path)
316
+
317
+ return output_image
318
+
319
+
320
  def invert_rgba_pil(
321
  image: Image.Image, mask: Image.Image, save_path: str = None
322
  ) -> Image.Image:
 
393
  model_type="vit_h",
394
  )
395
  remover = RembgRemover()
396
+ clean_image = remover(input_image)
397
+ clean_image.save(output_image)
398
  get_segmented_image(
399
  Image.open(input_image), remover, remover, None, "./test_seg.png"
400
  )
401
+
402
+ remover = BMGG14Remover()
403
+ remover("asset3d_gen/models/test_seg.jpg", "./seg.png")
common.py CHANGED
@@ -24,6 +24,7 @@ from asset3d_gen.models.segment_model import (
24
  RembgRemover,
25
  SAMPredictor,
26
  trellis_preprocess,
 
27
  )
28
  from asset3d_gen.models.sr_model import ImageRealESRGAN
29
  from asset3d_gen.scripts.render_gs import entrypoint as render_gs_api
@@ -149,7 +150,8 @@ def download_kolors_weights() -> None:
149
 
150
 
151
  if os.getenv("GRADIO_APP") == "imageto3d":
152
- RBG_REMOVER = RembgRemover()
 
153
  SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
154
  PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
155
  "JeffreyXiang/TRELLIS-image-large"
@@ -163,7 +165,8 @@ if os.getenv("GRADIO_APP") == "imageto3d":
163
  os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
164
  )
165
  elif os.getenv("GRADIO_APP") == "textto3d":
166
- RBG_REMOVER = RembgRemover()
 
167
  PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
168
  "JeffreyXiang/TRELLIS-image-large"
169
  )
 
24
  RembgRemover,
25
  SAMPredictor,
26
  trellis_preprocess,
27
+ BMGG14Remover,
28
  )
29
  from asset3d_gen.models.sr_model import ImageRealESRGAN
30
  from asset3d_gen.scripts.render_gs import entrypoint as render_gs_api
 
150
 
151
 
152
  if os.getenv("GRADIO_APP") == "imageto3d":
153
+ # RBG_REMOVER = RembgRemover()
154
+ RBG_REMOVER = BMGG14Remover()
155
  SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
156
  PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
157
  "JeffreyXiang/TRELLIS-image-large"
 
165
  os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
166
  )
167
  elif os.getenv("GRADIO_APP") == "textto3d":
168
+ # RBG_REMOVER = RembgRemover()
169
+ RBG_REMOVER = BMGG14Remover()
170
  PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
171
  "JeffreyXiang/TRELLIS-image-large"
172
  )