put hyper-parameters
Browse files
app.py
CHANGED
@@ -22,39 +22,79 @@ maskclip = MaskClip().to(device)
|
|
22 |
dino = DINO().to(device)
|
23 |
to_torch_tensor = T.Compose([T.Resize(size=448, max_size=2048), T.ToTensor()])
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
@spaces.GPU
|
26 |
-
def segment_image(img: PIL.Image.Image, classnames: str, use_lposs_plus: bool | None
|
|
|
27 |
img_tensor = to_torch_tensor(PIL.Image.fromarray(img)).unsqueeze(0).to(device)
|
28 |
classnames = [c.strip() for c in classnames.split(",")]
|
29 |
num_classes = len(classnames)
|
|
|
|
|
|
|
30 |
|
31 |
-
preds = lposs(maskclip, dino, img_tensor, classnames)
|
32 |
if use_lposs_plus:
|
33 |
-
preds = lposs_plus(img_tensor, preds)
|
34 |
preds = F.interpolate(preds, size=img.shape[:-1], mode="bilinear", align_corners=False)
|
35 |
preds = F.softmax(preds * 100, dim=1).cpu().numpy()
|
36 |
return (img, [(preds[0, i, :, :], classnames[i]) for i in range(num_classes)])
|
37 |
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
gr.Image(label="Input Image"),
|
42 |
-
gr.Textbox(label="Class Names", info="Separate class names with commas"),
|
43 |
-
gr.Checkbox(label="Use LPOSS+", info="Enable pixel-level refinement using LPOSS+")
|
44 |
-
],
|
45 |
-
outputs=[
|
46 |
-
gr.AnnotatedImage(label="Segmentation Results")
|
47 |
-
],
|
48 |
-
title="LPOSS: Label Propagation Over Patches and Pixels for Open-vocabulary Semantic Segmentation",
|
49 |
-
article="""<div align='center' style='margin: 1em 0;'>
|
50 |
<a href='http://arxiv.org/abs/2503.19777' target='_blank' style='margin-right: 2em; text-decoration: none; font-weight: bold;'>
|
51 |
π arXiv
|
52 |
</a>
|
53 |
<a href='https://github.com/vladan-stojnic/LPOSS' target='_blank' style='text-decoration: none; font-weight: bold;'>
|
54 |
π» GitHub
|
55 |
</a>
|
56 |
-
</div>"""
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
demo.launch()
|
|
|
22 |
dino = DINO().to(device)
|
23 |
to_torch_tensor = T.Compose([T.Resize(size=448, max_size=2048), T.ToTensor()])
|
24 |
|
25 |
+
# Default hyperparameter values
|
26 |
+
DEFAULT_SIGMA = 100
|
27 |
+
DEFAULT_ALPHA = 0.95
|
28 |
+
DEFAULT_K = 400
|
29 |
+
DEFAULT_WSIZE = 224
|
30 |
+
DEFAULT_GAMMA = 3.0
|
31 |
+
DEFAULT_TAU = 0.01
|
32 |
+
|
33 |
+
# Function to reset hyperparameters to default values
|
34 |
+
def reset_hyperparams():
|
35 |
+
return DEFAULT_WSIZE, DEFAULT_K, DEFAULT_GAMMA, DEFAULT_ALPHA, DEFAULT_SIGMA, DEFAULT_TAU
|
36 |
+
|
37 |
@spaces.GPU
|
38 |
+
def segment_image(img: PIL.Image.Image, classnames: str, use_lposs_plus: bool | None,
|
39 |
+
winodw_size:int, k:int, gamma:float, alpha:float, sigma: float, tau:float) -> tuple[np.ndarray | PIL.Image.Image | str, list[tuple[np.ndarray | tuple[int, int, int, int], str]]]:
|
40 |
img_tensor = to_torch_tensor(PIL.Image.fromarray(img)).unsqueeze(0).to(device)
|
41 |
classnames = [c.strip() for c in classnames.split(",")]
|
42 |
num_classes = len(classnames)
|
43 |
+
|
44 |
+
winodw_size = (winodw_size, winodw_size)
|
45 |
+
stride = (winodw_size[0] // 2, winodw_size[1] // 2)
|
46 |
|
47 |
+
preds = lposs(maskclip, dino, img_tensor, classnames, window_size=winodw_size, window_stride=stride, sigma=1/sigma, lp_k_image=k, lp_gamma=gamma, lp_alpha=alpha)
|
48 |
if use_lposs_plus:
|
49 |
+
preds = lposs_plus(img_tensor, preds, tau=tau, alpha=alpha)
|
50 |
preds = F.interpolate(preds, size=img.shape[:-1], mode="bilinear", align_corners=False)
|
51 |
preds = F.softmax(preds * 100, dim=1).cpu().numpy()
|
52 |
return (img, [(preds[0, i, :, :], classnames[i]) for i in range(num_classes)])
|
53 |
|
54 |
+
with gr.Blocks() as demo:
|
55 |
+
gr.Markdown("# LPOSS: Label Propagation Over Patches and Pixels for Open-vocabulary Semantic Segmentation")
|
56 |
+
gr.Markdown("""<div align='center' style='margin: 1em 0;'>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
<a href='http://arxiv.org/abs/2503.19777' target='_blank' style='margin-right: 2em; text-decoration: none; font-weight: bold;'>
|
58 |
π arXiv
|
59 |
</a>
|
60 |
<a href='https://github.com/vladan-stojnic/LPOSS' target='_blank' style='text-decoration: none; font-weight: bold;'>
|
61 |
π» GitHub
|
62 |
</a>
|
63 |
+
</div>""")
|
64 |
+
gr.Markdown("Upload an image and specify the objects you want to segment by listing their names separated by commas.")
|
65 |
+
|
66 |
+
with gr.Row(variant="panel"):
|
67 |
+
with gr.Column(scale=1):
|
68 |
+
with gr.Row():
|
69 |
+
gr.Markdown("Hyper-parameters")
|
70 |
+
with gr.Row():
|
71 |
+
window_size = gr.Slider(minimum=112, maximum=448, value=DEFAULT_WSIZE, step=16, label="Window Size")
|
72 |
+
k = gr.Slider(minimum=50, maximum=800, value=DEFAULT_K, step=50, label="k")
|
73 |
+
gamma = gr.Slider(minimum=0.0, maximum=10.0, value=DEFAULT_GAMMA, step=0.5, label="Gamma")
|
74 |
+
alpha = gr.Slider(minimum=0.0, maximum=1.0, value=DEFAULT_ALPHA, step=0.05, label="Alpha")
|
75 |
+
sigma = gr.Slider(minimum=50, maximum=400, value=DEFAULT_SIGMA, step=10, label="Sigma")
|
76 |
+
tau = gr.Slider(minimum=0.0, maximum=1.0, value=DEFAULT_TAU, step=0.01, label="Tau")
|
77 |
+
with gr.Row():
|
78 |
+
reset_btn = gr.Button("Reset to Default Values")
|
79 |
+
|
80 |
+
with gr.Row():
|
81 |
+
with gr.Column(scale=2):
|
82 |
+
input_image = gr.Image(label="Input Image")
|
83 |
+
class_names = gr.Textbox(label="Class Names", info="Separate class names with commas")
|
84 |
+
use_lposs_plus = gr.Checkbox(label="Use LPOSS+", info="Enable pixel-level refinement using LPOSS+")
|
85 |
+
|
86 |
+
with gr.Column(scale=3):
|
87 |
+
output_image = gr.AnnotatedImage(label="Segmentation Results")
|
88 |
+
|
89 |
+
with gr.Row():
|
90 |
+
segment_btn = gr.Button("Segment Image")
|
91 |
+
|
92 |
+
reset_btn.click(fn=reset_hyperparams, outputs=[window_size, k, gamma, alpha, sigma, tau])
|
93 |
+
|
94 |
+
segment_btn.click(
|
95 |
+
fn=segment_image,
|
96 |
+
inputs=[input_image, class_names, use_lposs_plus, window_size, k, gamma, alpha, sigma, tau],
|
97 |
+
outputs=[output_image]
|
98 |
+
)
|
99 |
|
100 |
demo.launch()
|