WensongSong commited on
Commit
60464bd
·
verified ·
1 Parent(s): 47a1b08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -36,9 +36,10 @@ SAM_ENCODER_VERSION = "vit_h"
36
  SAM_CHECKPOINT_PATH = "./sam_vit_h_4b8939.pth"
37
 
38
  # Building GroundingDINO inference model
39
- groundingdino_model = load_model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, device="cpu")
40
  # Building SAM Model and SAM Predictor
41
  sam = build_sam(checkpoint=SAM_CHECKPOINT_PATH)
 
42
  sam_predictor = SamPredictor(sam)
43
 
44
  def transform_image(image_pil):
@@ -128,7 +129,7 @@ def get_mask(image, label):
128
  sam_predictor.set_image(image)
129
 
130
  transformed_boxes = sam_predictor.transform.apply_boxes_torch(
131
- boxes_filt, image.shape[:2])
132
 
133
  masks, _, _ = sam_predictor.predict_torch(
134
  point_coords=None,
@@ -359,7 +360,7 @@ with gr.Blocks() as demo:
359
  text_prompt = gr.Textbox(label="Label")
360
 
361
  with gr.Column(scale=1):
362
- baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=721, columns=1)
363
  with gr.Accordion("Advanced Option", open=True):
364
  seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=666)
365
  gr.Markdown("### Guidelines")
 
36
  SAM_CHECKPOINT_PATH = "./sam_vit_h_4b8939.pth"
37
 
38
  # Building GroundingDINO inference model
39
+ groundingdino_model = load_model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, device="cuda")
40
  # Building SAM Model and SAM Predictor
41
  sam = build_sam(checkpoint=SAM_CHECKPOINT_PATH)
42
+ sam.to(device="cuda")
43
  sam_predictor = SamPredictor(sam)
44
 
45
  def transform_image(image_pil):
 
129
  sam_predictor.set_image(image)
130
 
131
  transformed_boxes = sam_predictor.transform.apply_boxes_torch(
132
+ boxes_filt, image.shape[:2]).to("cuda")
133
 
134
  masks, _, _ = sam_predictor.predict_torch(
135
  point_coords=None,
 
360
  text_prompt = gr.Textbox(label="Label")
361
 
362
  with gr.Column(scale=1):
363
+ baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=800, columns=1)
364
  with gr.Accordion("Advanced Option", open=True):
365
  seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=666)
366
  gr.Markdown("### Guidelines")