File size: 2,942 Bytes
57bcd95
72efb9f
 
 
 
 
 
 
 
 
 
 
57bcd95
 
72efb9f
57bcd95
 
 
 
 
 
 
 
 
 
 
 
268e766
57bcd95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import sys
import os
import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import glob
import gradio as gr
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
import logging
from huggingface_hub import hf_hub_download

# Fetch the helper module `utils.py` from the Hub dataset repo and import it.
# Names used below but not defined in this file (title, description,
# descriptionend, get_select_coords, interactive_seg, auto_seg, clear_point,
# reset_state) presumably come from this star import -- verify against utils.py.
#
# HUB_TOKEN is read with .get(): when it is unset, token=None lets
# huggingface_hub fall back to its own stored credentials instead of this
# script failing with a bare KeyError.
token = os.environ.get('HUB_TOKEN')
loc = hf_hub_download(
    repo_id="JunchuanYu/files_for_segmentRS",
    filename="utils.py",
    repo_type="dataset",
    local_dir='.',
    token=token,
)
# hf_hub_download returns the path of the downloaded *file*. sys.path entries
# must be directories, so register the directory that contains utils.py.
sys.path.append(os.path.dirname(os.path.abspath(loc)))
from utils import *  # noqa: F401,F403 -- UI texts and segmentation callbacks

# Build the Gradio interface. The callbacks and Markdown texts referenced here
# are not defined in this file; presumably they come from the downloaded
# `utils` module via the star import above.
with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown(title)
    with gr.Accordion("Instructions For User 👉", open=False):
        gr.Markdown(description)

    # Per-session state passed to and returned from get_select_coords,
    # interactive_seg, clear_point and reset_state. Presumably the clicked
    # point coordinates and their positive/negative labels -- confirm in utils.
    x = gr.State(value=[])
    y = gr.State(value=[])
    label = gr.State(value=[])

    with gr.Row():
        with gr.Column(scale=13):
            with gr.Row():
                with gr.Column():
                    # gr.inputs.* is the deprecated legacy namespace (removed in
                    # Gradio 4); the top-level component takes `value=`.
                    mode = gr.Radio(
                        ['Positive', 'Negative'],
                        type="value",
                        value='Positive',
                        label='Types of sampling methods',
                    )
                with gr.Column():
                    clear_bn = gr.Button("Clear Selection")
                    interseg_button = gr.Button("Interactive Segment", variant='primary')
            with gr.Row():
                input_img = gr.Image(label="Input")
                gallery = gr.Image(label="Points")

            # Clicks on the input image go to get_select_coords, which updates
            # the "Points" preview and the x/y/label session state.
            input_img.select(get_select_coords, [input_img, mode, x, y, label], [gallery, x, y, label])

            with gr.Row():
                output_img = gr.Image(label="Result")
                mask_img = gr.Image(label="Mask")
            with gr.Row():
                with gr.Column():
                    thresh = gr.Slider(minimum=0.8, maximum=1, value=0.90, step=0.01, interactive=True, label="Threshold")
                with gr.Column():
                    points = gr.Slider(minimum=16, maximum=96, value=32, step=16, interactive=True, label="Points/Side")

        with gr.Column(scale=2, min_width=8):
            # sorted(): glob order is filesystem-dependent; keep the example
            # list stable across machines and restarts.
            example = gr.Examples(
                examples=[[s, 0.9, 32] for s in sorted(glob.glob('./images/*'))],
                fn=auto_seg,
                inputs=[input_img, thresh, points],
                # Same output component as the "Auto Segment" button below.
                outputs=[mask_img],
                cache_examples=False,
                examples_per_page=5,
            )

    autoseg_button = gr.Button("Auto Segment", variant="primary")
    emptyBtn = gr.Button("Restart", variant="secondary")

    interseg_button.click(interactive_seg, inputs=[input_img, x, y, label], outputs=[output_img, mask_img])
    autoseg_button.click(auto_seg, inputs=[input_img, thresh, points], outputs=[mask_img])

    clear_bn.click(clear_point, outputs=[gallery, mode, x, y, label], show_progress=True)
    # reset_state must return one value per output below, in this order.
    emptyBtn.click(
        reset_state,
        outputs=[input_img, gallery, output_img, mask_img, thresh, points, mode, x, y, label],
        show_progress=True,
    )

    gr.Markdown(descriptionend)
if __name__ == "__main__":
    # Serve the UI only when executed as a script, with the API page
    # disabled (show_api=False) and debug mode off.
    demo.launch(show_api=False, debug=False)