Amordia commited on
Commit
2bd25a3
·
verified ·
1 Parent(s): 58ab35a

Improve interactivity by allowing direct bounding box drawing and tool selection

Browse files

- Updated click_img component to support interactive selection of points and rectangles
- Added tool_select radio button to choose between 'Point' and 'Rectangle' tools
- Modified get_select_coords function to handle rectangle selections
- Updated visualization to correctly display bounding boxes

Files changed (1) hide show
  1. app.py +22 -7
app.py CHANGED
@@ -129,11 +129,9 @@ def viz_pred_mask(img,
129
  else:
130
  cv2.circle(out,(col, row), marker_size, (255,0,0), -1)
131
 
132
- if bbox_coords is not None:
133
- for i in range(len(bbox_coords)//2):
134
- cv2.rectangle(out, bbox_coords[2*i], bbox_coords[2*i+1], (255,165,0), marker_size)
135
- if len(bbox_coords) % 2 == 1:
136
- cv2.circle(out, tuple(bbox_coords[-1]), marker_size, (255,165,0), -1)
137
 
138
  return out.astype(np.uint8)
139
 
@@ -242,7 +240,8 @@ def refresh_predictions(predictor, input_img, output_img, click_coords, click_la
242
  def get_select_coords(predictor, input_img, brush_label, bbox_label, best_mask, low_res_mask,
243
  click_coords, click_labels, bbox_coords,
244
  seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
245
- output_img, binary_checkbox, multimask_mode, autopredict_checkbox, evt: gr.SelectData):
 
246
  """
247
  Record user click and update the prediction
248
  """
@@ -255,6 +254,16 @@ def get_select_coords(predictor, input_img, brush_label, bbox_label, best_mask,
255
  else:
256
  raise TypeError("Invalid brush label: {brush_label}")
257
 
 
 
 
 
 
 
 
 
 
 
258
  # Only make new prediction if not waiting for additional bounding box click
259
  if (len(bbox_coords) % 2 == 0) and autopredict_checkbox:
260
 
@@ -402,15 +411,20 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
402
  )
403
 
404
  with gr.Tab("Clicks/Boxes") as click_tab:
 
405
  click_img = gr.Image(
406
  label="Input",
407
  type='numpy',
408
  value=default_example,
 
 
409
  show_download_button=True,
410
  container=True,
411
  height=display_height
412
  )
413
  with gr.Row():
 
 
414
  undo_click_button = gr.Button("Undo Last Click")
415
  clear_click_button = gr.Button("Clear Clicks/Boxes", variant="stop")
416
 
@@ -546,7 +560,8 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
546
  input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
547
  seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
548
  output_img, binary_checkbox, multimask_mode, autopredict_checkbox
549
- ],
 
550
  outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
551
  click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
552
  api_name = "get_select_coords"
 
129
  else:
130
  cv2.circle(out,(col, row), marker_size, (255,0,0), -1)
131
 
132
+ if bbox_coords:
133
+ for bbox in bbox_coords:
134
+ cv2.rectangle(out, bbox[0], bbox[1], (255,165,0), 2)
 
 
135
 
136
  return out.astype(np.uint8)
137
 
 
240
  def get_select_coords(predictor, input_img, brush_label, bbox_label, best_mask, low_res_mask,
241
  click_coords, click_labels, bbox_coords,
242
  seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
243
+ output_img, binary_checkbox, multimask_mode, autopredict_checkbox,
244
+ tool_select, evt: gr.SelectData):
245
  """
246
  Record user click and update the prediction
247
  """
 
254
  else:
255
  raise TypeError("Invalid brush label: {brush_label}")
256
 
257
+ if tool_select == "Rectangle" and evt.bounding_box:
258
+ # User drew a rectangle
259
+ bbox_coords.append((evt.bounding_box[0], evt.bounding_box[1]))
260
+ elif tool_select == "Point" and evt.index is not None:
261
+ # User clicked a point
262
+ click_coords.append(evt.index)
263
+ click_labels.append(1 if brush_label == 'Positive (green)' else 0)
264
+ else:
265
+ gr.Error("Invalid selection")
266
+
267
  # Only make new prediction if not waiting for additional bounding box click
268
  if (len(bbox_coords) % 2 == 0) and autopredict_checkbox:
269
 
 
411
  )
412
 
413
  with gr.Tab("Clicks/Boxes") as click_tab:
414
+ # Update click_img to be interactive and use the selected tool
415
  click_img = gr.Image(
416
  label="Input",
417
  type='numpy',
418
  value=default_example,
419
+ tool="select", # Use 'select' to capture both points and rectangles
420
+ interactive=True,
421
  show_download_button=True,
422
  container=True,
423
  height=display_height
424
  )
425
  with gr.Row():
426
+ # Add a tool selection radio button
427
+ tool_select = gr.Radio(["Point", "Rectangle"], label="Tool", value="Point")
428
  undo_click_button = gr.Button("Undo Last Click")
429
  clear_click_button = gr.Button("Clear Clicks/Boxes", variant="stop")
430
 
 
560
  input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
561
  seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
562
  output_img, binary_checkbox, multimask_mode, autopredict_checkbox
563
+ tool_select # Add tool_select here
564
+ ],
565
  outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
566
  click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
567
  api_name = "get_select_coords"