Improve interactivity by allowing direct bounding box drawing and tool selection (#1)
Browse files- Improve interactivity by allowing direct bounding box drawing and tool selection (2bd25a32ae2413702dec7566dab24761aa0e114f)
app.py
CHANGED
@@ -129,11 +129,9 @@ def viz_pred_mask(img,
|
|
129 |
else:
|
130 |
cv2.circle(out,(col, row), marker_size, (255,0,0), -1)
|
131 |
|
132 |
-
if bbox_coords
|
133 |
-
for
|
134 |
-
cv2.rectangle(out,
|
135 |
-
if len(bbox_coords) % 2 == 1:
|
136 |
-
cv2.circle(out, tuple(bbox_coords[-1]), marker_size, (255,165,0), -1)
|
137 |
|
138 |
return out.astype(np.uint8)
|
139 |
|
@@ -242,7 +240,8 @@ def refresh_predictions(predictor, input_img, output_img, click_coords, click_la
|
|
242 |
def get_select_coords(predictor, input_img, brush_label, bbox_label, best_mask, low_res_mask,
|
243 |
click_coords, click_labels, bbox_coords,
|
244 |
seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
|
245 |
-
output_img, binary_checkbox, multimask_mode, autopredict_checkbox,
|
|
|
246 |
"""
|
247 |
Record user click and update the prediction
|
248 |
"""
|
@@ -255,6 +254,16 @@ def get_select_coords(predictor, input_img, brush_label, bbox_label, best_mask,
|
|
255 |
else:
|
256 |
raise TypeError("Invalid brush label: {brush_label}")
|
257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
258 |
# Only make new prediction if not waiting for additional bounding box click
|
259 |
if (len(bbox_coords) % 2 == 0) and autopredict_checkbox:
|
260 |
|
@@ -402,15 +411,20 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
|
|
402 |
)
|
403 |
|
404 |
with gr.Tab("Clicks/Boxes") as click_tab:
|
|
|
405 |
click_img = gr.Image(
|
406 |
label="Input",
|
407 |
type='numpy',
|
408 |
value=default_example,
|
|
|
|
|
409 |
show_download_button=True,
|
410 |
container=True,
|
411 |
height=display_height
|
412 |
)
|
413 |
with gr.Row():
|
|
|
|
|
414 |
undo_click_button = gr.Button("Undo Last Click")
|
415 |
clear_click_button = gr.Button("Clear Clicks/Boxes", variant="stop")
|
416 |
|
@@ -546,7 +560,8 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
|
|
546 |
input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
|
547 |
seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
|
548 |
output_img, binary_checkbox, multimask_mode, autopredict_checkbox
|
549 |
-
|
|
|
550 |
outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
|
551 |
click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
|
552 |
api_name = "get_select_coords"
|
|
|
129 |
else:
|
130 |
cv2.circle(out,(col, row), marker_size, (255,0,0), -1)
|
131 |
|
132 |
+
if bbox_coords:
|
133 |
+
for bbox in bbox_coords:
|
134 |
+
cv2.rectangle(out, bbox[0], bbox[1], (255,165,0), 2)
|
|
|
|
|
135 |
|
136 |
return out.astype(np.uint8)
|
137 |
|
|
|
240 |
def get_select_coords(predictor, input_img, brush_label, bbox_label, best_mask, low_res_mask,
|
241 |
click_coords, click_labels, bbox_coords,
|
242 |
seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
|
243 |
+
output_img, binary_checkbox, multimask_mode, autopredict_checkbox,
|
244 |
+
tool_select, evt: gr.SelectData):
|
245 |
"""
|
246 |
Record user click and update the prediction
|
247 |
"""
|
|
|
254 |
else:
|
255 |
raise TypeError("Invalid brush label: {brush_label}")
|
256 |
|
257 |
+
if tool_select == "Rectangle" and evt.bounding_box:
|
258 |
+
# User drew a rectangle
|
259 |
+
bbox_coords.append((evt.bounding_box[0], evt.bounding_box[1]))
|
260 |
+
elif tool_select == "Point" and evt.index is not None:
|
261 |
+
# User clicked a point
|
262 |
+
click_coords.append(evt.index)
|
263 |
+
click_labels.append(1 if brush_label == 'Positive (green)' else 0)
|
264 |
+
else:
|
265 |
+
gr.Error("Invalid selection")
|
266 |
+
|
267 |
# Only make new prediction if not waiting for additional bounding box click
|
268 |
if (len(bbox_coords) % 2 == 0) and autopredict_checkbox:
|
269 |
|
|
|
411 |
)
|
412 |
|
413 |
with gr.Tab("Clicks/Boxes") as click_tab:
|
414 |
+
# Update click_img to be interactive and use the selected tool
|
415 |
click_img = gr.Image(
|
416 |
label="Input",
|
417 |
type='numpy',
|
418 |
value=default_example,
|
419 |
+
tool="select", # Use 'select' to capture both points and rectangles
|
420 |
+
interactive=True,
|
421 |
show_download_button=True,
|
422 |
container=True,
|
423 |
height=display_height
|
424 |
)
|
425 |
with gr.Row():
|
426 |
+
# Add a tool selection radio button
|
427 |
+
tool_select = gr.Radio(["Point", "Rectangle"], label="Tool", value="Point")
|
428 |
undo_click_button = gr.Button("Undo Last Click")
|
429 |
clear_click_button = gr.Button("Clear Clicks/Boxes", variant="stop")
|
430 |
|
|
|
560 |
input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
|
561 |
seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
|
562 |
output_img, binary_checkbox, multimask_mode, autopredict_checkbox
|
563 |
+
tool_select # Add tool_select here
|
564 |
+
],
|
565 |
outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
|
566 |
click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
|
567 |
api_name = "get_select_coords"
|