huzey commited on
Commit
bbd0121
·
1 Parent(s): 4f504b9
Files changed (1) hide show
  1. app.py +78 -0
app.py CHANGED
@@ -927,6 +927,44 @@ def ncut_run(
927
 
928
  return to_pil_images(rgb_all), logging_str
929
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
930
  # ailgnedcut
931
  if not directed:
932
  only_eigvecs = kwargs.get("only_eigvecs", False)
@@ -1318,6 +1356,7 @@ def run_fn(
1318
  return_eigvec_and_rgb=False,
1319
  normalize_eigvec_return=False,
1320
  separate_fg_bg=False,
 
1321
  ):
1322
  # print(node_type2, head_index_text, make_symmetric)
1323
  progress=gr.Progress()
@@ -1463,6 +1502,7 @@ def run_fn(
1463
  "return_eigvec_and_rgb": return_eigvec_and_rgb,
1464
  "normalize_eigvec_return": normalize_eigvec_return,
1465
  "separate_fg_bg": separate_fg_bg,
 
1466
  }
1467
  # print(kwargs)
1468
 
@@ -4348,6 +4388,44 @@ with demo:
4348
  outputs=[output_gallery, logging_text],
4349
  )
4350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4351
 
4352
  with gr.Tab('Sub-cluster (dev)', visible=False) as sub_cluster_tab:
4353
  with gr.Row():
 
927
 
928
  return to_pil_images(rgb_all), logging_str
929
 
930
+ kway = kwargs.get("kway", False)
931
+ if kway:
932
+ only_eigvecs = True
933
+ rgb, _logging_str, eigvecs = compute_ncut(
934
+ features,
935
+ num_eig=num_eig,
936
+ num_sample_ncut=num_sample_ncut,
937
+ affinity_focal_gamma=affinity_focal_gamma,
938
+ knn_ncut=knn_ncut,
939
+ knn_tsne=knn_tsne,
940
+ num_sample_tsne=num_sample_tsne,
941
+ embedding_method=embedding_method,
942
+ embedding_metric=embedding_metric,
943
+ perplexity=perplexity,
944
+ n_neighbors=n_neighbors,
945
+ min_dist=min_dist,
946
+ sampling_method=sampling_method,
947
+ indirect_connection=indirect_connection,
948
+ make_orthogonal=make_orthogonal,
949
+ metric=ncut_metric,
950
+ only_eigvecs=only_eigvecs,
951
+ )
952
+ from ncut_pytorch import kway_ncut
953
+ kway_onehot = kway_ncut(eigvecs) # [N, K]
954
+ kway_indices = kway_onehot.argmax(dim=-1) # [N]
955
+ kway_indices = kway_indices.cpu().numpy()
956
+ if kway_indices.max() > 10:
957
+ cm = plt.colormaps['tab20']
958
+ rgb = cm(kway_indices / 20)
959
+ else:
960
+ cm = plt.colormaps['tab10']
961
+ rgb = cm(kway_indices / 10)
962
+ if kway_indices.max() > 20:
963
+ gr.Error("Too many clusters for kway_ncut")
964
+ rgb = rgb[:, :3]
965
+ rgb = rgb.reshape(*features.shape[:-1], 3)
966
+ return to_pil_images(rgb), logging_str
967
+
968
  # ailgnedcut
969
  if not directed:
970
  only_eigvecs = kwargs.get("only_eigvecs", False)
 
1356
  return_eigvec_and_rgb=False,
1357
  normalize_eigvec_return=False,
1358
  separate_fg_bg=False,
1359
+ kway=False,
1360
  ):
1361
  # print(node_type2, head_index_text, make_symmetric)
1362
  progress=gr.Progress()
 
1502
  "return_eigvec_and_rgb": return_eigvec_and_rgb,
1503
  "normalize_eigvec_return": normalize_eigvec_return,
1504
  "separate_fg_bg": separate_fg_bg,
1505
+ "kway": kway,
1506
  }
1507
  # print(kwargs)
1508
 
 
4388
  outputs=[output_gallery, logging_text],
4389
  )
4390
 
4391
+
4392
+ with gr.Tab('K-way'):
4393
+
4394
+ with gr.Row():
4395
+ with gr.Column(scale=5, min_width=200):
4396
+ input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section()
4397
+ num_images_slider.value = 30
4398
+ logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
4399
+
4400
+ with gr.Column(scale=5, min_width=200):
4401
+ output_gallery = make_output_images_section()
4402
+ # cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
4403
+ [
4404
+ model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
4405
+ affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
4406
+ embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
4407
+ perplexity_slider, n_neighbors_slider, min_dist_slider,
4408
+ sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
4409
+ ] = make_parameters_section()
4410
+ num_eig_slider.value = 6
4411
+ num_eig_slider.maximum = 20
4412
+
4413
+ false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
4414
+ no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
4415
+
4416
+ submit_button.click(
4417
+ partial(run_fn, n_ret=1, plot_clusters=False, kway=True),
4418
+ inputs=[
4419
+ input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
4420
+ positive_prompt, negative_prompt,
4421
+ false_placeholder, no_prompt, no_prompt, no_prompt,
4422
+ affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
4423
+ embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
4424
+ perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown,
4425
+ *[false_placeholder]*12,
4426
+ ],
4427
+ outputs=[output_gallery, logging_text],
4428
+ )
4429
 
4430
  with gr.Tab('Sub-cluster (dev)', visible=False) as sub_cluster_tab:
4431
  with gr.Row():