stojnvla committed on
Commit
6fe4585
·
1 Parent(s): 4370bbd

put hyper-parameters

Browse files
Files changed (1) hide show
  1. app.py +58 -18
app.py CHANGED
@@ -22,39 +22,79 @@ maskclip = MaskClip().to(device)
22
  dino = DINO().to(device)
23
  to_torch_tensor = T.Compose([T.Resize(size=448, max_size=2048), T.ToTensor()])
24
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  @spaces.GPU
26
- def segment_image(img: PIL.Image.Image, classnames: str, use_lposs_plus: bool | None) -> tuple[np.ndarray | PIL.Image.Image | str, list[tuple[np.ndarray | tuple[int, int, int, int], str]]]:
 
27
  img_tensor = to_torch_tensor(PIL.Image.fromarray(img)).unsqueeze(0).to(device)
28
  classnames = [c.strip() for c in classnames.split(",")]
29
  num_classes = len(classnames)
 
 
 
30
 
31
- preds = lposs(maskclip, dino, img_tensor, classnames)
32
  if use_lposs_plus:
33
- preds = lposs_plus(img_tensor, preds)
34
  preds = F.interpolate(preds, size=img.shape[:-1], mode="bilinear", align_corners=False)
35
  preds = F.softmax(preds * 100, dim=1).cpu().numpy()
36
  return (img, [(preds[0, i, :, :], classnames[i]) for i in range(num_classes)])
37
 
38
- demo = gr.Interface(
39
- fn=segment_image,
40
- inputs=[
41
- gr.Image(label="Input Image"),
42
- gr.Textbox(label="Class Names", info="Separate class names with commas"),
43
- gr.Checkbox(label="Use LPOSS+", info="Enable pixel-level refinement using LPOSS+")
44
- ],
45
- outputs=[
46
- gr.AnnotatedImage(label="Segmentation Results")
47
- ],
48
- title="LPOSS: Label Propagation Over Patches and Pixels for Open-vocabulary Semantic Segmentation",
49
- article="""<div align='center' style='margin: 1em 0;'>
50
  <a href='http://arxiv.org/abs/2503.19777' target='_blank' style='margin-right: 2em; text-decoration: none; font-weight: bold;'>
51
  📄 arXiv
52
  </a>
53
  <a href='https://github.com/vladan-stojnic/LPOSS' target='_blank' style='text-decoration: none; font-weight: bold;'>
54
  💻 GitHub
55
  </a>
56
- </div>""",
57
- description="Upload an image and specify the objects you want to segment by listing their names separated by commas.",
58
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  demo.launch()
 
22
  dino = DINO().to(device)
23
  to_torch_tensor = T.Compose([T.Resize(size=448, max_size=2048), T.ToTensor()])
24
 
25
+ # Default hyperparameter values
26
+ DEFAULT_SIGMA = 100
27
+ DEFAULT_ALPHA = 0.95
28
+ DEFAULT_K = 400
29
+ DEFAULT_WSIZE = 224
30
+ DEFAULT_GAMMA = 3.0
31
+ DEFAULT_TAU = 0.01
32
+
33
+ # Function to reset hyperparameters to default values
34
+ def reset_hyperparams():
35
+ return DEFAULT_WSIZE, DEFAULT_K, DEFAULT_GAMMA, DEFAULT_ALPHA, DEFAULT_SIGMA, DEFAULT_TAU
36
+
37
  @spaces.GPU
38
+ def segment_image(img: PIL.Image.Image, classnames: str, use_lposs_plus: bool | None,
39
+ winodw_size:int, k:int, gamma:float, alpha:float, sigma: float, tau:float) -> tuple[np.ndarray | PIL.Image.Image | str, list[tuple[np.ndarray | tuple[int, int, int, int], str]]]:
40
  img_tensor = to_torch_tensor(PIL.Image.fromarray(img)).unsqueeze(0).to(device)
41
  classnames = [c.strip() for c in classnames.split(",")]
42
  num_classes = len(classnames)
43
+
44
+ winodw_size = (winodw_size, winodw_size)
45
+ stride = (winodw_size[0] // 2, winodw_size[1] // 2)
46
 
47
+ preds = lposs(maskclip, dino, img_tensor, classnames, window_size=winodw_size, window_stride=stride, sigma=1/sigma, lp_k_image=k, lp_gamma=gamma, lp_alpha=alpha)
48
  if use_lposs_plus:
49
+ preds = lposs_plus(img_tensor, preds, tau=tau, alpha=alpha)
50
  preds = F.interpolate(preds, size=img.shape[:-1], mode="bilinear", align_corners=False)
51
  preds = F.softmax(preds * 100, dim=1).cpu().numpy()
52
  return (img, [(preds[0, i, :, :], classnames[i]) for i in range(num_classes)])
53
 
54
+ with gr.Blocks() as demo:
55
+ gr.Markdown("# LPOSS: Label Propagation Over Patches and Pixels for Open-vocabulary Semantic Segmentation")
56
+ gr.Markdown("""<div align='center' style='margin: 1em 0;'>
 
 
 
 
 
 
 
 
 
57
  <a href='http://arxiv.org/abs/2503.19777' target='_blank' style='margin-right: 2em; text-decoration: none; font-weight: bold;'>
58
  📄 arXiv
59
  </a>
60
  <a href='https://github.com/vladan-stojnic/LPOSS' target='_blank' style='text-decoration: none; font-weight: bold;'>
61
  💻 GitHub
62
  </a>
63
+ </div>""")
64
+ gr.Markdown("Upload an image and specify the objects you want to segment by listing their names separated by commas.")
65
+
66
+ with gr.Row(variant="panel"):
67
+ with gr.Column(scale=1):
68
+ with gr.Row():
69
+ gr.Markdown("Hyper-parameters")
70
+ with gr.Row():
71
+ window_size = gr.Slider(minimum=112, maximum=448, value=DEFAULT_WSIZE, step=16, label="Window Size")
72
+ k = gr.Slider(minimum=50, maximum=800, value=DEFAULT_K, step=50, label="k")
73
+ gamma = gr.Slider(minimum=0.0, maximum=10.0, value=DEFAULT_GAMMA, step=0.5, label="Gamma")
74
+ alpha = gr.Slider(minimum=0.0, maximum=1.0, value=DEFAULT_ALPHA, step=0.05, label="Alpha")
75
+ sigma = gr.Slider(minimum=50, maximum=400, value=DEFAULT_SIGMA, step=10, label="Sigma")
76
+ tau = gr.Slider(minimum=0.0, maximum=1.0, value=DEFAULT_TAU, step=0.01, label="Tau")
77
+ with gr.Row():
78
+ reset_btn = gr.Button("Reset to Default Values")
79
+
80
+ with gr.Row():
81
+ with gr.Column(scale=2):
82
+ input_image = gr.Image(label="Input Image")
83
+ class_names = gr.Textbox(label="Class Names", info="Separate class names with commas")
84
+ use_lposs_plus = gr.Checkbox(label="Use LPOSS+", info="Enable pixel-level refinement using LPOSS+")
85
+
86
+ with gr.Column(scale=3):
87
+ output_image = gr.AnnotatedImage(label="Segmentation Results")
88
+
89
+ with gr.Row():
90
+ segment_btn = gr.Button("Segment Image")
91
+
92
+ reset_btn.click(fn=reset_hyperparams, outputs=[window_size, k, gamma, alpha, sigma, tau])
93
+
94
+ segment_btn.click(
95
+ fn=segment_image,
96
+ inputs=[input_image, class_names, use_lposs_plus, window_size, k, gamma, alpha, sigma, tau],
97
+ outputs=[output_image]
98
+ )
99
 
100
  demo.launch()