stojnvla committed on
Commit
be73ac3
·
1 Parent(s): 6fe4585

update hyper

Browse files
Files changed (2) hide show
  1. app.py +12 -10
  2. lposs.py +2 -2
app.py CHANGED
@@ -29,14 +29,15 @@ DEFAULT_K = 400
29
  DEFAULT_WSIZE = 224
30
  DEFAULT_GAMMA = 3.0
31
  DEFAULT_TAU = 0.01
 
32
 
33
  # Function to reset hyperparameters to default values
34
  def reset_hyperparams():
35
- return DEFAULT_WSIZE, DEFAULT_K, DEFAULT_GAMMA, DEFAULT_ALPHA, DEFAULT_SIGMA, DEFAULT_TAU
36
 
37
  @spaces.GPU
38
  def segment_image(img: PIL.Image.Image, classnames: str, use_lposs_plus: bool | None,
39
- winodw_size:int, k:int, gamma:float, alpha:float, sigma: float, tau:float) -> tuple[np.ndarray | PIL.Image.Image | str, list[tuple[np.ndarray | tuple[int, int, int, int], str]]]:
40
  img_tensor = to_torch_tensor(PIL.Image.fromarray(img)).unsqueeze(0).to(device)
41
  classnames = [c.strip() for c in classnames.split(",")]
42
  num_classes = len(classnames)
@@ -46,7 +47,7 @@ def segment_image(img: PIL.Image.Image, classnames: str, use_lposs_plus: bool |
46
 
47
  preds = lposs(maskclip, dino, img_tensor, classnames, window_size=winodw_size, window_stride=stride, sigma=1/sigma, lp_k_image=k, lp_gamma=gamma, lp_alpha=alpha)
48
  if use_lposs_plus:
49
- preds = lposs_plus(img_tensor, preds, tau=tau, alpha=alpha)
50
  preds = F.interpolate(preds, size=img.shape[:-1], mode="bilinear", align_corners=False)
51
  preds = F.softmax(preds * 100, dim=1).cpu().numpy()
52
  return (img, [(preds[0, i, :, :], classnames[i]) for i in range(num_classes)])
@@ -69,11 +70,12 @@ with gr.Blocks() as demo:
69
  gr.Markdown("Hyper-parameters")
70
  with gr.Row():
71
  window_size = gr.Slider(minimum=112, maximum=448, value=DEFAULT_WSIZE, step=16, label="Window Size")
72
- k = gr.Slider(minimum=50, maximum=800, value=DEFAULT_K, step=50, label="k")
73
- gamma = gr.Slider(minimum=0.0, maximum=10.0, value=DEFAULT_GAMMA, step=0.5, label="Gamma")
74
- alpha = gr.Slider(minimum=0.0, maximum=1.0, value=DEFAULT_ALPHA, step=0.05, label="Alpha")
75
- sigma = gr.Slider(minimum=50, maximum=400, value=DEFAULT_SIGMA, step=10, label="Sigma")
76
- tau = gr.Slider(minimum=0.0, maximum=1.0, value=DEFAULT_TAU, step=0.01, label="Tau")
 
77
  with gr.Row():
78
  reset_btn = gr.Button("Reset to Default Values")
79
 
@@ -89,11 +91,11 @@ with gr.Blocks() as demo:
89
  with gr.Row():
90
  segment_btn = gr.Button("Segment Image")
91
 
92
- reset_btn.click(fn=reset_hyperparams, outputs=[window_size, k, gamma, alpha, sigma, tau])
93
 
94
  segment_btn.click(
95
  fn=segment_image,
96
- inputs=[input_image, class_names, use_lposs_plus, window_size, k, gamma, alpha, sigma, tau],
97
  outputs=[output_image]
98
  )
99
 
 
29
  DEFAULT_WSIZE = 224
30
  DEFAULT_GAMMA = 3.0
31
  DEFAULT_TAU = 0.01
32
+ DEFAULT_R = 13
33
 
34
  # Function to reset hyperparameters to default values
35
  def reset_hyperparams():
36
+ return DEFAULT_WSIZE, DEFAULT_K, DEFAULT_GAMMA, DEFAULT_ALPHA, DEFAULT_SIGMA, DEFAULT_TAU, DEFAULT_R
37
 
38
  @spaces.GPU
39
  def segment_image(img: PIL.Image.Image, classnames: str, use_lposs_plus: bool | None,
40
+ winodw_size:int, k:int, gamma:float, alpha:float, sigma: float, tau:float, r:int) -> tuple[np.ndarray | PIL.Image.Image | str, list[tuple[np.ndarray | tuple[int, int, int, int], str]]]:
41
  img_tensor = to_torch_tensor(PIL.Image.fromarray(img)).unsqueeze(0).to(device)
42
  classnames = [c.strip() for c in classnames.split(",")]
43
  num_classes = len(classnames)
 
47
 
48
  preds = lposs(maskclip, dino, img_tensor, classnames, window_size=winodw_size, window_stride=stride, sigma=1/sigma, lp_k_image=k, lp_gamma=gamma, lp_alpha=alpha)
49
  if use_lposs_plus:
50
+ preds = lposs_plus(img_tensor, preds, tau=tau, alpha=alpha, r=r)
51
  preds = F.interpolate(preds, size=img.shape[:-1], mode="bilinear", align_corners=False)
52
  preds = F.softmax(preds * 100, dim=1).cpu().numpy()
53
  return (img, [(preds[0, i, :, :], classnames[i]) for i in range(num_classes)])
 
70
  gr.Markdown("Hyper-parameters")
71
  with gr.Row():
72
  window_size = gr.Slider(minimum=112, maximum=448, value=DEFAULT_WSIZE, step=16, label="Window Size")
73
+ k = gr.Slider(minimum=50, maximum=800, value=DEFAULT_K, step=50, label="k (LPOSS number of graph neighbors)")
74
+ gamma = gr.Slider(minimum=0.0, maximum=10.0, value=DEFAULT_GAMMA, step=0.5, label="γ (LPOSS graph edge weight)")
75
+ sigma = gr.Slider(minimum=50, maximum=400, value=DEFAULT_SIGMA, step=10, label="σ (LPOSS spatial affinity weight)")
76
+ tau = gr.Slider(minimum=0.0, maximum=1.0, value=DEFAULT_TAU, step=0.01, label="τ (LPOSS+ appearance affinity weight)")
77
+ r = gr.Slider(minimum=3, maximum=15, value=DEFAULT_R, step=2, label="r (LPOSS+ kernel size)")
78
+ alpha = gr.Slider(minimum=0.0, maximum=1.0, value=DEFAULT_ALPHA, step=0.05, label="α (amount of propagation)")
79
  with gr.Row():
80
  reset_btn = gr.Button("Reset to Default Values")
81
 
 
91
  with gr.Row():
92
  segment_btn = gr.Button("Segment Image")
93
 
94
+ reset_btn.click(fn=reset_hyperparams, outputs=[window_size, k, gamma, alpha, sigma, tau, r])
95
 
96
  segment_btn.click(
97
  fn=segment_image,
98
+ inputs=[input_image, class_names, use_lposs_plus, window_size, k, gamma, alpha, sigma, tau, r],
99
  outputs=[output_image]
100
  )
101
 
lposs.py CHANGED
@@ -277,13 +277,13 @@ def get_laplacian(rows, cols, data, N, alpha=0.99):
277
  return L
278
 
279
 
280
- def lposs_plus(img, preds, tau=0.01, alpha=0.95):
281
  preds = preds[0, ...]
282
  num_classes, h_img, w_img = preds.shape
283
  preds = preds.permute((1, 2, 0))
284
  preds = preds.reshape((h_img*w_img, -1))
285
 
286
- rows, cols, pixel_pixel_data, locs = get_pixel_connections(img, neigh=6)
287
  pixel_pixel_data = torch.sqrt(pixel_pixel_data)
288
  pixel_pixel_data = torch.exp(-pixel_pixel_data / tau)
289
  L = get_laplacian(rows, cols, pixel_pixel_data, preds.shape[0], alpha=alpha)
 
277
  return L
278
 
279
 
280
+ def lposs_plus(img, preds, tau=0.01, alpha=0.95, r=13):
281
  preds = preds[0, ...]
282
  num_classes, h_img, w_img = preds.shape
283
  preds = preds.permute((1, 2, 0))
284
  preds = preds.reshape((h_img*w_img, -1))
285
 
286
+ rows, cols, pixel_pixel_data, locs = get_pixel_connections(img, neigh=r//2)
287
  pixel_pixel_data = torch.sqrt(pixel_pixel_data)
288
  pixel_pixel_data = torch.exp(-pixel_pixel_data / tau)
289
  L = get_laplacian(rows, cols, pixel_pixel_data, preds.shape[0], alpha=alpha)