LPX55 commited on
Commit
8602a2c
·
1 Parent(s): f7e52d6

no shrooms

Browse files
Files changed (1) hide show
  1. sam2_mask.py +13 -8
sam2_mask.py CHANGED
@@ -12,10 +12,9 @@ MODEL = "facebook/sam2.1-hiera-large"
12
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
  @spaces.GPU()
 
15
  def predict_masks(image, points):
16
  """Predict a single mask from the image based on selected points."""
17
- global PREDICTOR
18
- PREDICTOR = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)
19
  image_np = np.array(image)
20
  points_list = [[point["x"], point["y"]] for point in points]
21
  input_labels = [1] * len(points_list)
@@ -37,8 +36,15 @@ def predict_masks(image, points):
37
  else:
38
  return image_np
39
 
 
40
  def create_sam2_mask_interface():
41
  """Create the Gradio interface for SAM2 mask generation."""
 
 
 
 
 
 
42
  with gr.Blocks() as sam2_mask_tab:
43
  gr.Markdown("# Object Segmentation with SAM2")
44
  gr.Markdown(
@@ -49,15 +55,14 @@ def create_sam2_mask_interface():
49
  with gr.Row():
50
  with gr.Column():
51
  upload_image_input = ImagePrompter(show_label=False)
52
- submit_button = gr.Button("Submit")
53
  with gr.Column():
54
  image_output = gr.Image(label="Segmented Image", type="pil", height=400)
55
 
56
- # Define the action triggered by the submit button
57
- submit_button.click(
58
- fn=predict_masks,
59
- inputs=[upload_image_input.image, upload_image_input.points],
60
- outputs=image_output,
61
  show_progress=True,
62
  )
63
 
 
12
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
  @spaces.GPU()
15
+
16
  def predict_masks(image, points):
17
  """Predict a single mask from the image based on selected points."""
 
 
18
  image_np = np.array(image)
19
  points_list = [[point["x"], point["y"]] for point in points]
20
  input_labels = [1] * len(points_list)
 
36
  else:
37
  return image_np
38
 
39
+
40
  def create_sam2_mask_interface():
41
  """Create the Gradio interface for SAM2 mask generation."""
42
+ def update_mask(prompts):
43
+ """Update the mask based on the prompts."""
44
+ image = prompts["image"]
45
+ points = prompts["points"]
46
+ return predict_masks(image, points)
47
+
48
  with gr.Blocks() as sam2_mask_tab:
49
  gr.Markdown("# Object Segmentation with SAM2")
50
  gr.Markdown(
 
55
  with gr.Row():
56
  with gr.Column():
57
  upload_image_input = ImagePrompter(show_label=False)
 
58
  with gr.Column():
59
  image_output = gr.Image(label="Segmented Image", type="pil", height=400)
60
 
61
+ # Define the action triggered by the upload_image_input change
62
+ upload_image_input.change(
63
+ fn=update_mask,
64
+ inputs=[upload_image_input],
65
+ outputs=[image_output],
66
  show_progress=True,
67
  )
68