Spaces:
Running
on
Zero
Running
on
Zero
no shrooms
Browse files- sam2_mask.py +13 -8
sam2_mask.py
CHANGED
@@ -12,10 +12,9 @@ MODEL = "facebook/sam2.1-hiera-large"
|
|
12 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
13 |
|
14 |
@spaces.GPU()
|
|
|
15 |
def predict_masks(image, points):
|
16 |
"""Predict a single mask from the image based on selected points."""
|
17 |
-
global PREDICTOR
|
18 |
-
PREDICTOR = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)
|
19 |
image_np = np.array(image)
|
20 |
points_list = [[point["x"], point["y"]] for point in points]
|
21 |
input_labels = [1] * len(points_list)
|
@@ -37,8 +36,15 @@ def predict_masks(image, points):
|
|
37 |
else:
|
38 |
return image_np
|
39 |
|
|
|
40 |
def create_sam2_mask_interface():
|
41 |
"""Create the Gradio interface for SAM2 mask generation."""
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
with gr.Blocks() as sam2_mask_tab:
|
43 |
gr.Markdown("# Object Segmentation with SAM2")
|
44 |
gr.Markdown(
|
@@ -49,15 +55,14 @@ def create_sam2_mask_interface():
|
|
49 |
with gr.Row():
|
50 |
with gr.Column():
|
51 |
upload_image_input = ImagePrompter(show_label=False)
|
52 |
-
submit_button = gr.Button("Submit")
|
53 |
with gr.Column():
|
54 |
image_output = gr.Image(label="Segmented Image", type="pil", height=400)
|
55 |
|
56 |
-
# Define the action triggered by the
|
57 |
-
|
58 |
-
fn=
|
59 |
-
inputs=[upload_image_input
|
60 |
-
outputs=image_output,
|
61 |
show_progress=True,
|
62 |
)
|
63 |
|
|
|
12 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
13 |
|
14 |
@spaces.GPU()
|
15 |
+
|
16 |
def predict_masks(image, points):
|
17 |
"""Predict a single mask from the image based on selected points."""
|
|
|
|
|
18 |
image_np = np.array(image)
|
19 |
points_list = [[point["x"], point["y"]] for point in points]
|
20 |
input_labels = [1] * len(points_list)
|
|
|
36 |
else:
|
37 |
return image_np
|
38 |
|
39 |
+
|
40 |
def create_sam2_mask_interface():
|
41 |
"""Create the Gradio interface for SAM2 mask generation."""
|
42 |
+
def update_mask(prompts):
|
43 |
+
"""Update the mask based on the prompts."""
|
44 |
+
image = prompts["image"]
|
45 |
+
points = prompts["points"]
|
46 |
+
return predict_masks(image, points)
|
47 |
+
|
48 |
with gr.Blocks() as sam2_mask_tab:
|
49 |
gr.Markdown("# Object Segmentation with SAM2")
|
50 |
gr.Markdown(
|
|
|
55 |
with gr.Row():
|
56 |
with gr.Column():
|
57 |
upload_image_input = ImagePrompter(show_label=False)
|
|
|
58 |
with gr.Column():
|
59 |
image_output = gr.Image(label="Segmented Image", type="pil", height=400)
|
60 |
|
61 |
+
# Define the action triggered by the upload_image_input change
|
62 |
+
upload_image_input.change(
|
63 |
+
fn=update_mask,
|
64 |
+
inputs=[upload_image_input],
|
65 |
+
outputs=[image_output],
|
66 |
show_progress=True,
|
67 |
)
|
68 |
|