File size: 4,889 Bytes
06d49db
 
 
 
 
 
 
 
 
11fd0ad
06d49db
 
 
 
 
 
 
 
 
 
 
 
 
 
6fe4585
 
 
 
 
 
 
be73ac3
6fe4585
 
 
be73ac3
6fe4585
11fd0ad
6fe4585
be73ac3
06d49db
 
 
6fe4585
 
 
06d49db
6fe4585
06d49db
be73ac3
06d49db
 
 
 
6fe4585
 
 
4370bbd
 
6514616
4370bbd
 
6514616
6fe4585
 
 
38737c2
6fe4585
38737c2
 
6fe4585
 
be73ac3
121a9f3
 
 
be73ac3
 
6fe4585
 
38737c2
 
 
 
 
 
 
6fe4585
 
 
 
38737c2
 
6fe4585
 
 
 
be73ac3
6fe4585
 
 
be73ac3
6fe4585
 
06d49db
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import gradio as gr
import PIL
import numpy as np
from models.maskclip import MaskClip
from models.dino import DINO
import torchvision.transforms as T
import torch.nn.functional as F
from lposs import lposs, lposs_plus
import torch
import spaces

# Run on CUDA when PyTorch can see a GPU; otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    print("Using GPU")
# An MPS branch (torch.backends.mps.is_available() -> device = "mps") is
# left disabled here.

print(f"Using device: {device}")

# Both models are built once at import time and moved to the chosen device.
maskclip = MaskClip().to(device)
dino = DINO().to(device)

# Preprocessing applied to every uploaded image: resize with size=448
# (max_size=2048), then convert the PIL image to a tensor.
to_torch_tensor = T.Compose(
    [
        T.Resize(size=448, max_size=2048),
        T.ToTensor(),
    ]
)

# Default values for the hyper-parameter sliders, listed in the same order in
# which reset_hyperparams() hands them back.
DEFAULT_WSIZE = 224   # "Window Size"
DEFAULT_K = 400       # k: LPOSS number of graph neighbors
DEFAULT_GAMMA = 3.0   # gamma: LPOSS graph edge tuning
DEFAULT_ALPHA = 0.95  # alpha: amount of propagation
DEFAULT_SIGMA = 100   # sigma: LPOSS spatial affinity tuning
DEFAULT_TAU = 0.01    # tau: LPOSS+ appearance affinity tuning
DEFAULT_R = 13        # r: LPOSS+ kernel size


def reset_hyperparams():
    """Return every default hyper-parameter value as a single tuple.

    The tuple is ordered (window size, k, gamma, alpha, sigma, tau, r); the
    "Reset to Default Values" button maps these positionally onto its output
    sliders, so the order here must stay in sync with that wiring.
    """
    defaults = (
        DEFAULT_WSIZE,
        DEFAULT_K,
        DEFAULT_GAMMA,
        DEFAULT_ALPHA,
        DEFAULT_SIGMA,
        DEFAULT_TAU,
        DEFAULT_R,
    )
    return defaults

@spaces.GPU
def segment_image(img: np.ndarray | None, classnames: str, use_lposs_plus: bool | None,
                  winodw_size: int, k: int, gamma: float, alpha: float, sigma: float, tau: float, r: int
                  ) -> tuple[np.ndarray, list[tuple[np.ndarray, str]]]:
    """Segment ``img`` into the comma-separated classes using LPOSS / LPOSS+.

    Args:
        img: Uploaded image as a numpy array (it is passed to
            ``PIL.Image.fromarray`` and its ``.shape`` is read). ``None`` when
            the user has not uploaded anything.
        classnames: Comma-separated class names; surrounding whitespace is
            stripped and blank entries are ignored.
        use_lposs_plus: When truthy, refine the LPOSS output with
            ``lposs_plus``.
        winodw_size: Side length of the square window passed to ``lposs``; the
            stride is half of it. (The misspelled name is kept so existing
            keyword callers keep working.)
        k: Forwarded to ``lposs`` as ``lp_k_image``.
        gamma: Forwarded to ``lposs`` as ``lp_gamma``.
        alpha: Forwarded to ``lposs`` as ``lp_alpha`` and to ``lposs_plus`` as
            ``alpha``.
        sigma: Its reciprocal (``1 / sigma``) is forwarded to ``lposs`` as
            ``sigma``.
        tau: Forwarded to ``lposs_plus`` as ``tau``.
        r: Forwarded to ``lposs_plus`` as ``r``.

    Returns:
        ``(img, annotations)`` where ``annotations`` holds one
        ``(mask, class_name)`` pair per class; each mask is that class's
        softmax score map resized to the input image's height and width.

    Raises:
        gr.Error: If no image was provided or no non-empty class name was
            given, so the UI shows a readable message instead of a traceback.
    """
    if img is None:
        raise gr.Error("Please upload an image before running the segmentation.")

    # Drop blank entries so inputs such as "" or "cat, , dog," do not turn
    # into empty class names that are sent to the model.
    names = [name.strip() for name in (classnames or "").split(",")]
    names = [name for name in names if name]
    if not names:
        raise gr.Error("Please enter at least one class name.")

    # Defensive casts: these values are used as sizes/counts, and a UI slider
    # may deliver them as floats.
    side = int(winodw_size)
    window = (side, side)
    stride = (side // 2, side // 2)

    img_tensor = to_torch_tensor(PIL.Image.fromarray(img)).unsqueeze(0).to(device)

    # Pure inference: make sure no autograd graph is recorded.
    with torch.no_grad():
        preds = lposs(maskclip, dino, img_tensor, names, window_size=window, window_stride=stride,
                      sigma=1 / sigma, lp_k_image=int(k), lp_gamma=gamma, lp_alpha=alpha)
        if use_lposs_plus:
            preds = lposs_plus(img_tensor, preds, tau=tau, alpha=alpha, r=int(r))
        # Resize the predictions back to the uploaded image's (H, W); [:2]
        # also works for arrays without a trailing channel axis.
        preds = F.interpolate(preds, size=img.shape[:2], mode="bilinear", align_corners=False)
        # Scale by 100 before the softmax over the class axis, as in the
        # original implementation (sharpens the per-pixel distribution).
        preds = F.softmax(preds * 100, dim=1).cpu().numpy()

    return img, [(preds[0, i], name) for i, name in enumerate(names)]

# Build the demo UI. Components are laid out in the order they are created
# inside the nested context managers, so statement order here is the layout.
with gr.Blocks() as demo:
    # Page header: title, links to the paper and code, and a usage hint.
    gr.Markdown("# LPOSS: Label Propagation Over Patches and Pixels for Open-vocabulary Semantic Segmentation")
    gr.Markdown("""<div align='center' style='margin: 1em 0;'>
        <a href='http://arxiv.org/abs/2503.19777' target='_blank' style='margin-right: 2em; text-decoration: none; font-weight: bold;'>
            📄 arXiv
        </a>
        <a href='https://github.com/vladan-stojnic/LPOSS' target='_blank' style='text-decoration: none; font-weight: bold;'>
            💻 GitHub
        </a>
    </div>""")
    gr.Markdown("Upload an image and specify the objects you want to segment by listing their names separated by commas.")

    # Collapsed-by-default panel holding one slider per hyper-parameter, each
    # initialised from the matching DEFAULT_* constant, plus a reset button.
    with gr.Accordion("Hyper-parameters", open=False):
        with gr.Column(scale=1):
            # with gr.Row():
            #     gr.Markdown("Hyper-parameters")
            with gr.Row():
                window_size = gr.Slider(minimum=112, maximum=448, value=DEFAULT_WSIZE, step=16, label="Window Size")
                k = gr.Slider(minimum=50, maximum=800, value=DEFAULT_K, step=50, label="k (LPOSS number of graph neighbors)")
                gamma = gr.Slider(minimum=0.0, maximum=10.0, value=DEFAULT_GAMMA, step=0.5, label="γ (LPOSS graph edge tuning)")
                sigma = gr.Slider(minimum=50, maximum=400, value=DEFAULT_SIGMA, step=10, label="σ (LPOSS spatial affinity tuning)")
                tau = gr.Slider(minimum=0.0, maximum=1.0, value=DEFAULT_TAU, step=0.01, label="τ (LPOSS+ appearance affinity tuning)")
                r = gr.Slider(minimum=3, maximum=15, value=DEFAULT_R, step=2, label="r (LPOSS+ kernel size)")
                alpha = gr.Slider(minimum=0.0, maximum=1.0, value=DEFAULT_ALPHA, step=0.05, label="α (amount of propagation)")
            with gr.Row():
                reset_btn = gr.Button("Reset to Default Values")

    # Text prompt (comma-separated class names) and the LPOSS+ toggle.
    with gr.Row():
        class_names = gr.Textbox(label="Class Names", info="Separate class names with commas")
        use_lposs_plus = gr.Checkbox(label="Use LPOSS+", info="Enable pixel-level refinement using LPOSS+")

    with gr.Row():
        segment_btn = gr.Button("Segment Image")
    
    # Input image on the left (scale=2), annotated result on the right (scale=3).
    with gr.Row():
        with gr.Column(scale=2):
            input_image = gr.Image(label="Input Image")
            # class_names = gr.Textbox(label="Class Names", info="Separate class names with commas")
            # use_lposs_plus = gr.Checkbox(label="Use LPOSS+", info="Enable pixel-level refinement using LPOSS+")
        
        with gr.Column(scale=3):
            output_image = gr.AnnotatedImage(label="Segmentation Results")
    
    # The outputs list must stay in the same order as the tuple returned by
    # reset_hyperparams (window size, k, gamma, alpha, sigma, tau, r).
    reset_btn.click(fn=reset_hyperparams, outputs=[window_size, k, gamma, alpha, sigma, tau, r])
    
    # The inputs list is ordered to match segment_image's parameter order; its
    # (image, annotations) return value feeds the AnnotatedImage component.
    segment_btn.click(
        fn=segment_image,
        inputs=[input_image, class_names, use_lposs_plus, window_size, k, gamma, alpha, sigma, tau, r],
        outputs=[output_image]
    )

# Start the app. This runs unconditionally at module level (there is no
# `if __name__ == "__main__"` guard).
demo.launch()