|
from ast import Interactive |
|
from xml.sax.xmlreader import InputSource |
|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
import os |
|
import cv2 |
|
import pathlib |
|
import math |
|
|
|
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from predictor import Predictor

# Pixel height used for the image canvases in the UI.
display_height = 600
# Default (height, width) of the scribble-mask state before an image is loaded.
H = 256
W = 256

# Example images offered in the UI, listed in sorted file-name order.
test_example_dir = pathlib.Path("./test_examples")
test_examples = [
    str(test_example_dir.joinpath(example_name))
    for example_name in sorted(os.listdir(test_example_dir))
]

# The first example is what the canvases show on start-up.
default_example = test_examples[0]
exp_dir = pathlib.Path('./checkpoints')
default_model = 'ScribblePrompt-Unet'

# Maps a model name shown in the UI to its checkpoint file name inside `exp_dir`.
model_dict = {
    'ScribblePrompt-Unet': 'ScribblePrompt_unet_v1_nf192_res128.pt'
}
|
|
|
|
|
|
|
|
|
|
|
def load_model(exp_key: str = default_model):
    """
    Build a Predictor from the checkpoint registered under `exp_key`.

    Args:
        exp_key: key of `model_dict` naming the checkpoint to load.

    Returns:
        A `(predictor, None)` tuple. The second element is always None so that
        a caller wiring it to a cached-features slot clears that slot.

    Raises:
        ValueError: if `exp_key` is not a key of `model_dict`.
    """
    checkpoint_name = model_dict.get(exp_key)
    if checkpoint_name is None:
        # Without this guard the failure is an obscure `Path / None` TypeError.
        raise ValueError(
            f"Unknown model {exp_key!r}; available models: {sorted(model_dict)}"
        )

    return Predictor(exp_dir / checkpoint_name), None
|
|
|
|
|
|
|
|
|
|
|
def _get_overlay(img, lay, const_color="l_blue"): |
|
""" |
|
Helper function for preparing overlay |
|
""" |
|
assert lay.ndim==2, "Overlay must be 2D, got shape: " + str(lay.shape) |
|
|
|
if img.ndim == 2: |
|
img = np.repeat(img[...,None], 3, axis=-1) |
|
|
|
assert img.ndim==3, "Image must be 3D, got shape: " + str(img.shape) |
|
|
|
if const_color == "blue": |
|
const_color = 255*np.array([0, 0, 1]) |
|
elif const_color == "green": |
|
const_color = 255*np.array([0, 1, 0]) |
|
elif const_color == "red": |
|
const_color = 255*np.array([1, 0, 0]) |
|
elif const_color == "l_blue": |
|
const_color = np.array([31, 119, 180]) |
|
elif const_color == "orange": |
|
const_color = np.array([255, 127, 14]) |
|
else: |
|
raise NotImplementedError |
|
|
|
x,y = np.nonzero(lay) |
|
for i in range(img.shape[-1]): |
|
img[x,y,i] = const_color[i] |
|
|
|
return img |
|
|
|
def image_overlay(img, mask=None, scribbles=None, contour=False, alpha=0.5):
    """
    Overlay the mask and scribbles on the image if provided

    Args:
        img: 2D grayscale image.
        mask: optional 2D mask. Drawn as a green contour when `contour` is
            True, otherwise blended in light blue with a fixed weight of
            0.5 * mask value.
        scribbles: optional array indexed as `scribbles[0]` (positive, drawn
            green) and `scribbles[1]` (negative, drawn red).
        contour: draw the mask outline instead of a filled blend.
        alpha: opacity of the scribble colors over the image.

    Returns:
        3-channel image with the overlays applied.
    """
    assert img.ndim == 2, "Image must be 2D, got shape: " + str(img.shape)
    output = np.repeat(img[...,None], 3, axis=-1)

    if mask is not None:

        assert mask.ndim == 2, "Mask must be 2D, got shape: " + str(mask.shape)

        if contour:
            # NOTE(review): indexing [0] assumes findContours returns
            # (contours, hierarchy) as in OpenCV 4 - confirm the installed version
            contours = cv2.findContours((mask[...,None]>0.5).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(output, contours[0], -1, (0, 255, 0), 2)
        else:
            mask_overlay = _get_overlay(img, mask)
            mask2 = 0.5*np.repeat(mask[...,None], 3, axis=-1)
            output = cv2.convertScaleAbs(mask_overlay * mask2 + output * (1 - mask2))

    if scribbles is not None:
        for channel, color in ((0, "green"), (1, "red")):
            # _get_overlay paints 3D inputs in place, so hand it a copy.
            # Passing `output` itself made addWeighted blend the array with
            # itself and `alpha` had no effect.
            painted = _get_overlay(output.copy(), scribbles[channel,...], const_color=color)
            cv2.addWeighted(painted, alpha, output, 1 - alpha, 0, output)

    return output
|
|
|
|
|
def _to_cv_point(point):
    """Convert an (x, y) pair of numbers into the int tuple OpenCV expects."""
    return tuple(int(v) for v in point)


def viz_pred_mask(img,
                  mask=None,
                  point_coords=None,
                  point_labels=None,
                  bbox_coords=None,
                  seperate_scribble_masks=None,
                  binary=True):
    """
    Visualize image with clicks, scribbles, predicted mask overlaid

    Args:
        img: 2D numpy image.
        mask: optional predicted mask (numpy array or torch tensor).
        point_coords: optional list of (col, row) clicks.
        point_labels: labels matching `point_coords`; 1 is drawn green,
            anything else red.
        bbox_coords: optional list of bounding-box entries. Each entry is
            either a single (col, row) corner, in which case consecutive
            corners are paired into boxes, or a ready-made (corner, corner)
            pair. An unpaired trailing corner is drawn as an orange dot.
        seperate_scribble_masks: optional scribble masks passed to
            `image_overlay` as `scribbles`.
        binary: threshold the mask at 0.5 before drawing it.

    Returns:
        uint8 RGB visualization.
    """
    assert isinstance(img, np.ndarray), "Image must be numpy array, got type: " + str(type(img))
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()

    if binary and mask is not None:
        mask = 1*(mask > 0.5)

    out = image_overlay(img, mask=mask, scribbles=seperate_scribble_masks)

    height, width = img.shape[:2]
    # Keep markers at least one pixel in radius on images under 100 px
    marker_size = max(1, min(height, width)//100)

    if point_coords is not None:
        for (col, row), label in zip(point_coords, point_labels):
            color = (0,255,0) if label == 1 else (255,0,0)
            cv2.circle(out, (col, row), marker_size, color, -1)

    if bbox_coords:
        box_color = (255,165,0)
        pending_corner = None
        for entry in bbox_coords:
            if np.ndim(entry) >= 2:
                # Entry already holds both corners of a box
                cv2.rectangle(out, _to_cv_point(entry[0]), _to_cv_point(entry[1]), box_color, 2)
            elif pending_corner is None:
                # First corner of a box; wait for its partner
                pending_corner = entry
            else:
                cv2.rectangle(out, _to_cv_point(pending_corner), _to_cv_point(entry), box_color, 2)
                pending_corner = None
        if pending_corner is not None:
            # Show the first corner so the user can see where the box starts
            cv2.circle(out, _to_cv_point(pending_corner), marker_size, box_color, -1)

    return out.astype(np.uint8)
|
|
|
|
|
|
|
|
|
|
|
def get_scribbles(seperate_scribble_masks, last_scribble_mask, scribble_img):
    """
    Record scribbles

    Rebuilds the positive/negative scribble masks from the first drawing layer
    of the scribble canvas: pixels whose green channel exceeds 128 are
    positive, pixels whose red channel exceeds 128 are negative.

    Args:
        seperate_scribble_masks: current masks, array indexed [pos/neg, row, col].
        last_scribble_mask: passed through unchanged when there is no canvas.
        scribble_img: canvas value (dict with a 'layers' list) or None.

    Returns:
        `(seperate_scribble_masks, last_scribble_mask)`. When a canvas is
        given, the masks are replaced and `last_scribble_mask` is None. A
        canvas with no layers yields all-zero masks of the current shape.
    """
    assert isinstance(seperate_scribble_masks, np.ndarray), "seperate_scribble_masks must be numpy array, got type: " + str(type(seperate_scribble_masks))

    if scribble_img is None:
        return seperate_scribble_masks, last_scribble_mask

    layers = scribble_img.get('layers')
    if layers is None or len(layers) == 0:
        # Nothing is drawn on the canvas, so there are no scribbles to record;
        # previously this case crashed on `layers[0]`.
        return np.zeros_like(seperate_scribble_masks), None

    color_mask = layers[0]

    positive_scribbles = 1.0*(color_mask[...,1] > 128)
    negative_scribbles = 1.0*(color_mask[...,0] > 128)

    return np.stack([positive_scribbles, negative_scribbles], axis=0), None
|
|
|
def get_predictions(predictor, input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks,
                    low_res_mask, img_features, multimask_mode):
    """
    Make predictions

    Packs the image and every available prompt into tensors on `device` and
    calls `predictor.predict`.

    Args:
        predictor: object exposing `predict(prompts, img_features, multimask_mode=...)`.
        input_img: numpy image; it is divided by 255 before being passed on.
        click_coords: list of click coordinates (may be empty).
        click_labels: list of click labels (may be empty).
        bbox_coords: list of box corners; a box is only used when it holds
            exactly two entries.
        seperate_scribble_masks: numpy scribble masks or None.
        low_res_mask: previous low-resolution mask tensor or None.
        img_features: cached features forwarded to the predictor.
        multimask_mode: forwarded to the predictor.

    Returns:
        `(mask, img_features, low_res_mask)` as returned by the predictor.
    """
    box = None
    if len(bbox_coords) == 1:
        # gr.Error is an exception and does nothing unless raised. A warning
        # toast shows the message and lets the prediction continue without a
        # box, which is what the surrounding code expects.
        gr.Warning("Please click a second time to define the bounding box")
    elif len(bbox_coords) == 2:
        box = torch.Tensor(bbox_coords).flatten()[None,None,...].int().to(device)

    if seperate_scribble_masks is not None:
        scribble = torch.from_numpy(seperate_scribble_masks)[None,...].to(device)
    else:
        scribble = None

    prompts = dict(
        img=torch.from_numpy(input_img)[None,None,...].to(device)/255,
        point_coords=torch.Tensor([click_coords]).int().to(device) if len(click_coords)>0 else None,
        point_labels=torch.Tensor([click_labels]).int().to(device) if len(click_labels)>0 else None,
        scribble=scribble,
        mask_input=low_res_mask.to(device) if low_res_mask is not None else None,
        box=box,
    )

    mask, img_features, low_res_mask = predictor.predict(prompts, img_features, multimask_mode=multimask_mode)

    return mask, img_features, low_res_mask
|
|
|
def refresh_predictions(predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
                        scribble_img, seperate_scribble_masks, last_scribble_mask,
                        best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode):
    """
    Re-read the scribbles, run the model and rebuild every visualization.

    `output_img`, `brush_label` and the incoming `best_mask` are accepted for
    interface compatibility but are not read.

    Returns:
        `(click_input_viz, scribble_input_viz, out_viz, best_mask,
        low_res_mask, img_features, seperate_scribble_masks,
        last_scribble_mask)` where `click_input_viz` is the click-canvas image,
        `scribble_input_viz` is a dict with "background", "layers" and
        "composite" entries for the scribble canvas, and `out_viz` is the list
        [overlay, input image, mask image].
    """
    # Record any new scribbles
    seperate_scribble_masks, last_scribble_mask = get_scribbles(
        seperate_scribble_masks, last_scribble_mask, scribble_img
    )

    # Make prediction
    best_mask, img_features, low_res_mask = get_predictions(
        predictor, input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks, low_res_mask, img_features, multimask_mode
    )

    # Tensor.numpy() only works on CPU tensors without grad, so move first
    mask_to_viz = best_mask.detach().cpu().numpy()
    click_input_viz = viz_pred_mask(input_img, mask_to_viz, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox)

    canvas_shape = input_img.shape[:2]
    empty_channel = np.zeros(canvas_shape, dtype=np.uint8)
    full_channel = np.full(canvas_shape, 255, dtype=np.uint8)

    # Background for the scribble canvas: prediction and clicks, no scribbles
    bg = viz_pred_mask(input_img, mask_to_viz, click_coords, click_labels, bbox_coords, None, binary_checkbox)

    # Alpha channel of the scribble layer: opaque wherever something is drawn.
    # Cast to uint8 so the stacked RGBA layer has one dtype.
    layers = scribble_img.get('layers') if scribble_img is not None else None
    if layers:
        scribble_mask = (255*(layers[0] > 0).any(-1)).astype(np.uint8)
    else:
        # No canvas layer available: derive the alpha from the recorded masks
        scribble_mask = (255*(seperate_scribble_masks > 0).any(0)).astype(np.uint8)

    scribble_input_viz = {
        "background": np.stack([bg[...,i] for i in range(3)]+[full_channel], axis=-1),
        "layers": [np.stack([
            (255*seperate_scribble_masks[1]).astype(np.uint8),  # red = negative
            (255*seperate_scribble_masks[0]).astype(np.uint8),  # green = positive
            empty_channel,
            scribble_mask
        ], axis=-1)],
        "composite": np.stack([click_input_viz[...,i] for i in range(3)]+[empty_channel], axis=-1),
    }

    mask_rgb = mask_to_viz[...,None].repeat(axis=2, repeats=3)
    mask_img = (255*(mask_rgb > 0.5)).astype(np.uint8) if binary_checkbox else mask_rgb

    out_viz = [
        viz_pred_mask(input_img, mask_to_viz, point_coords=None, point_labels=None, bbox_coords=None, seperate_scribble_masks=None, binary=binary_checkbox),
        input_img,
        mask_img,
    ]

    return click_input_viz, scribble_input_viz, out_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask
|
|
|
|
|
def get_select_coords(predictor, input_img, brush_label, bbox_label, best_mask, low_res_mask,
                      click_coords, click_labels, bbox_coords,
                      seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
                      output_img, binary_checkbox, multimask_mode, autopredict_checkbox,
                      tool_select, evt: gr.SelectData):
    """
    Record user click and update the prediction

    Each click is stored exactly once. It becomes a bounding-box corner when
    the bounding-box checkbox is ticked or the "Rectangle" tool is selected,
    otherwise a point click labelled 1 (positive) or 0 (negative) according to
    `brush_label`. A new prediction is made only when auto-predict is on and
    the number of stored box corners is even.

    Returns:
        `(click_input_viz, scribble_input_viz, output, best_mask,
        low_res_mask, img_features, click_coords, click_labels, bbox_coords,
        seperate_scribble_masks, last_scribble_mask)`.

    Raises:
        TypeError: if a point click arrives with an unknown `brush_label`.
    """
    # The original code routed the click through two consecutive blocks and so
    # appended it twice; it also read `evt.bounding_box`, which gr.SelectData
    # does not document. Only `evt.index` is used here.
    if evt.index is None:
        gr.Warning("Invalid selection")
    elif bbox_label or tool_select == "Rectangle":
        bbox_coords.append(evt.index)
    elif brush_label in ['Positive (green)', 'Negative (red)']:
        click_coords.append(evt.index)
        click_labels.append(1 if brush_label=='Positive (green)' else 0)
    else:
        raise TypeError(f"Invalid brush label: {brush_label}")

    # Only make a new prediction when no bounding box is half-defined
    if (len(bbox_coords) % 2 == 0) and autopredict_checkbox:

        click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask = refresh_predictions(
            predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
            scribble_img, seperate_scribble_masks, last_scribble_mask,
            best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
        )
        return click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask

    # Otherwise just redraw the canvases with the previous prediction
    click_input_viz = viz_pred_mask(
        input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
    )
    scribble_input_viz = viz_pred_mask(
        input_img, best_mask, click_coords, click_labels, bbox_coords, None, binary_checkbox
    )

    return click_input_viz, scribble_input_viz, output_img, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask
|
|
|
|
|
def undo_click(predictor, input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
               seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
               output_img, binary_checkbox, multimask_mode, autopredict_checkbox):
    """
    Remove last click and then update the prediction

    When the bounding-box checkbox is ticked the last box corner is removed;
    otherwise the last point click and its label are removed. Empty lists are
    left untouched. A new prediction is made only when auto-predict is on and
    there are zero or two box corners.

    Returns:
        `(click_input_viz, scribble_input_viz, output, best_mask,
        low_res_mask, img_features, click_coords, click_labels, bbox_coords,
        seperate_scribble_masks, last_scribble_mask)`.

    Raises:
        TypeError: if `bbox_label` is off and `brush_label` is unknown.
    """
    if bbox_label:
        if len(bbox_coords) > 0:
            bbox_coords.pop()
    elif brush_label in ['Positive (green)', 'Negative (red)']:
        if len(click_coords) > 0:
            click_coords.pop()
            click_labels.pop()
    else:
        # f-string so the message shows the offending value, not "{brush_label}"
        raise TypeError(f"Invalid brush label: {brush_label}")

    # Only make a new prediction if no bounding box is half-defined
    if (len(bbox_coords)==0 or len(bbox_coords)==2) and autopredict_checkbox:

        click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask = refresh_predictions(
            predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
            scribble_img, seperate_scribble_masks, last_scribble_mask,
            best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
        )
        return click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask

    # Otherwise just redraw the canvases with the previous prediction
    click_input_viz = viz_pred_mask(
        input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
    )
    scribble_input_viz = viz_pred_mask(
        input_img, best_mask, click_coords, click_labels, bbox_coords, None, binary_checkbox
    )

    return click_input_viz, scribble_input_viz, output_img, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask
|
|
|
|
|
|
|
|
|
|
|
# Layout, callbacks and event wiring of the demo UI.
with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as demo:

    # ------------------------------------------------------------------
    # Session state
    # ------------------------------------------------------------------
    seperate_scribble_masks = gr.State(np.zeros((2, H, W), dtype=np.float32))
    last_scribble_mask = gr.State(np.zeros((H, W), dtype=np.float32))

    click_coords = gr.State([])
    click_labels = gr.State([])
    bbox_coords = gr.State([])

    # Model and the values carried over between predictions
    predictor = gr.State(load_model()[0])
    img_features = gr.State(None)
    best_mask = gr.State(None)
    low_res_mask = gr.State(None)

    # ------------------------------------------------------------------
    # Header and instructions
    # ------------------------------------------------------------------
    gr.HTML("""\
    <h1 style="text-align: center; font-size: 28pt;">ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Biomedical Image</h1>
    <p style="text-align: center; font-size: large;">
        <b>ScribblePrompt</b> is an interactive segmentation tool designed to help users segment <b>new</b> structures in medical images using scribbles, clicks <b>and</b> bounding boxes.
        [<a href="https://arxiv.org/abs/2312.07381">paper</a> | <a href="https://scribbleprompt.csail.mit.edu">website</a> | <a href="https://github.com/halleewong/ScribblePrompt">code</a>]
    </p>
    """)

    with gr.Accordion("Open for instructions!", open=False):
        gr.Markdown(
        """
        * Select an input image from the examples below or upload your own image through the <b>'Input Image'</b> tab.
        * Use the <b>'Scribbles'</b> tab to draw <span style='color:green'>positive</span> or <span style='color:red'>negative</span> scribbles.
            - Use the buttons in the top right hand corner of the canvas to undo or adjust the brush size
            - Note: the app cannot detect new scribbles drawn on top of previous scribbles in a different color. Please undo/erase the scribble before drawing on the same pixel in a different color.
        * Use the <b>'Clicks/Boxes'</b> tab to draw <span style='color:green'>positive</span> or <span style='color:red'>negative</span> clicks and <span style='color:orange'>bounding boxes</span> by placing two clicks.
        * The <b>'Output'</b> tab will show the model's prediction based on your current inputs and the previous prediction.
        * The <b>'Clear Input Mask'</b> button will clear the latest prediction (which is used as an input to the model).
        * The <b>'Clear All Inputs'</b> button will clear all inputs (including scribbles, clicks, bounding boxes, and the last prediction).
        """
        )

    # ------------------------------------------------------------------
    # Controls
    # ------------------------------------------------------------------
    with gr.Row():
        model_dropdown = gr.Dropdown(
            label="Model",
            choices = list(model_dict.keys()),
            value=default_model,
            multiselect=False,
            interactive=False,
            visible=False
        )

    with gr.Row():
        with gr.Column(scale=1):
            brush_label = gr.Radio(["Positive (green)", "Negative (red)"],
                                   value="Positive (green)", label="Scribble/Click Label")
            bbox_label = gr.Checkbox(value=False, label="Bounding Box (2 clicks)")
        with gr.Column(scale=1):
            binary_checkbox = gr.Checkbox(value=True, label="Show binary masks", visible=False)
            autopredict_checkbox = gr.Checkbox(value=True, label="Auto-update prediction on clicks")
            with gr.Accordion("Troubleshooting tips", open=False):
                gr.Markdown("<span style='color:orange'>If you encounter an <span style='color:orange'>error</span> try clicking 'Clear All Inputs'.</span>")
            multimask_mode = gr.Checkbox(value=True, label="Multi-mask mode", visible=False)

    # ------------------------------------------------------------------
    # Canvases and output
    # ------------------------------------------------------------------
    with gr.Row():

        green_brush = gr.Brush(colors=["#00FF00"], color_mode="fixed", default_size=3)
        red_brush = gr.Brush(colors=["#FF0000"], color_mode="fixed", default_size=3)

        with gr.Column(scale=1):
            with gr.Tab("Scribbles"):
                scribble_img = gr.ImageEditor(
                    label="Input",
                    image_mode="RGB",
                    brush=green_brush,
                    type='numpy',
                    value=default_example,
                    transforms=(),
                    sources=(),
                    container=True,
                    show_download_button=True,
                    height=display_height+60
                )

            with gr.Tab("Clicks/Boxes") as click_tab:
                # `tool="select"` was removed: it is not a gr.ImageEditor
                # constructor parameter.
                # NOTE(review): the click handlers read pixel coordinates from
                # `evt.index`; confirm this component's select event provides
                # them, otherwise a plain gr.Image may be the right component.
                click_img = gr.ImageEditor(
                    label="Input",
                    type='numpy',
                    value=default_example,
                    interactive=True,
                    show_download_button=True,
                    container=True,
                    height=display_height
                )
                with gr.Row():
                    tool_select = gr.Radio(["Point", "Rectangle"], label="Tool", value="Point")
                    undo_click_button = gr.Button("Undo Last Click")
                    clear_click_button = gr.Button("Clear Clicks/Boxes", variant="stop")

            with gr.Tab("Input Image"):
                input_img = gr.Image(
                    label="Input",
                    image_mode="L",
                    value=default_example,
                    show_download_button=True,
                    container=True,
                    height=display_height
                )
                gr.Markdown("To upload your own image: click the `x` in the top right corner to clear the current image, then drag & drop")

        with gr.Column(scale=1):
            with gr.Tab("Output"):
                output_img = gr.Gallery(
                    label='Output',
                    columns=1,
                    elem_id="gallery",
                    preview=True,
                    object_fit="scale-down",
                    height=display_height+60,
                    container=True
                )

            submit_button = gr.Button("Refresh Prediction", variant='primary')
            clear_all_button = gr.ClearButton([scribble_img], value="Clear All Inputs", variant="stop")
            clear_mask_button = gr.Button("Clear Input Mask")

    # ------------------------------------------------------------------
    # Model selection and examples
    # ------------------------------------------------------------------
    # load_model returns (predictor, None), so the cached features are cleared
    model_dropdown.change(fn=load_model,
                          inputs=[model_dropdown],
                          outputs=[predictor, img_features]
                          )

    gr.Examples(examples=test_examples,
                inputs=[input_img],
                examples_per_page=12,
                label='Examples from datasets unseen during training'
                )

    # ------------------------------------------------------------------
    # Clear buttons
    # ------------------------------------------------------------------
    def clear_click_history(input_img):
        """Reset both canvases to the input image and drop clicks, boxes and the last prediction."""
        return input_img, input_img, [], [], [], None, None

    clear_click_button.click(clear_click_history,
                             inputs=[input_img],
                             outputs=[click_img, scribble_img, click_coords, click_labels, bbox_coords, best_mask, low_res_mask])

    def clear_all_history(input_img):
        """
        Reset every input and cached value.

        The scribble masks are re-created as zeros matching the input image
        size, or (H, W) when there is no image.
        """
        if input_img is not None:
            input_shape = input_img.shape[:2]
        else:
            input_shape = (H, W)
        return input_img, input_img, [], [], [], [], np.zeros((2,)+input_shape, dtype=np.float32), np.zeros(input_shape, dtype=np.float32), None, None, None

    input_img.change(clear_all_history,
                     inputs=[input_img],
                     outputs=[click_img, scribble_img,
                              output_img, click_coords, click_labels, bbox_coords,
                              seperate_scribble_masks, last_scribble_mask,
                              best_mask, low_res_mask, img_features
                              ])

    clear_all_button.click(clear_all_history,
                           inputs=[input_img],
                           outputs=[click_img, scribble_img,
                                    output_img, click_coords, click_labels, bbox_coords,
                                    seperate_scribble_masks, last_scribble_mask,
                                    best_mask, low_res_mask, img_features
                                    ])

    def clear_best_mask(input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks):
        """Drop the last prediction and redraw both canvases without a mask."""
        click_input_viz = viz_pred_mask(
            input_img, None, click_coords, click_labels, bbox_coords, seperate_scribble_masks
        )
        scribble_input_viz = viz_pred_mask(
            input_img, None, click_coords, click_labels, bbox_coords, None
        )

        return None, None, click_input_viz, scribble_input_viz

    clear_mask_button.click(
        clear_best_mask,
        inputs=[input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks],
        outputs=[best_mask, low_res_mask, click_img, scribble_img],
    )

    # ------------------------------------------------------------------
    # Clicks, prediction refresh and undo
    # ------------------------------------------------------------------
    click_img.select(get_select_coords,
                     inputs=[
                         predictor,
                         input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
                         seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
                         output_img, binary_checkbox, multimask_mode, autopredict_checkbox,
                         tool_select
                     ],
                     outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
                              click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
                     api_name = "get_select_coords"
                     )

    submit_button.click(fn=refresh_predictions,
                        inputs=[
                            predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
                            scribble_img, seperate_scribble_masks, last_scribble_mask,
                            best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
                        ],
                        outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
                                 seperate_scribble_masks, last_scribble_mask],
                        api_name="refresh_predictions"
                        )

    undo_click_button.click(fn=undo_click,
                            inputs=[
                                predictor,
                                input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords,
                                click_labels, bbox_coords,
                                seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
                                output_img, binary_checkbox, multimask_mode, autopredict_checkbox
                            ],
                            outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
                                     click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
                            api_name="undo_click"
                            )

    # ------------------------------------------------------------------
    # Keep the click canvas in sync with the scribbles
    # ------------------------------------------------------------------
    def update_click_img(input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox,
                         last_scribble_mask, scribble_img, brush_label, best_mask):
        """
        Draw scribbles in the click canvas
        """
        seperate_scribble_masks, last_scribble_mask = get_scribbles(
            seperate_scribble_masks, last_scribble_mask, scribble_img
        )
        click_input_viz = viz_pred_mask(
            input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
        )
        return click_input_viz, seperate_scribble_masks, last_scribble_mask

    click_tab.select(fn=update_click_img,
                     inputs=[input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks,
                             binary_checkbox, last_scribble_mask, scribble_img, brush_label, best_mask],
                     outputs=[click_img, seperate_scribble_masks, last_scribble_mask],
                     api_name="update_click_img"
                     )

    # ------------------------------------------------------------------
    # Brush color
    # ------------------------------------------------------------------
    def change_brush_color(seperate_scribble_masks, last_scribble_mask, scribble_img, label):
        """
        Switch the scribble brush to the color matching `label`.

        The scribble masks are returned unchanged; only the brush of the
        scribble canvas is updated.
        """
        if label == "Negative (red)":
            brush_update = gr.update(brush=red_brush)
        elif label == "Positive (green)":
            brush_update = gr.update(brush=green_brush)
        else:
            raise TypeError("Invalid brush color")

        return seperate_scribble_masks, last_scribble_mask, brush_update

    brush_label.change(fn=change_brush_color,
                       inputs=[seperate_scribble_masks, last_scribble_mask, scribble_img, brush_label],
                       outputs=[seperate_scribble_masks, last_scribble_mask, scribble_img],
                       api_name="change_brush_color"
                       )
|
|
|
|
|
if __name__ == "__main__":
    # Enable the request queue with its API closed, then start the server
    # without the API documentation page.
    queued_demo = demo.queue(api_open=False)
    queued_demo.launch(show_api=False)
|
|