Spaces:
Running
on
Zero
Running
on
Zero
add kway
Browse files
app.py
CHANGED
@@ -927,6 +927,44 @@ def ncut_run(
|
|
927 |
|
928 |
return to_pil_images(rgb_all), logging_str
|
929 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
930 |
# ailgnedcut
|
931 |
if not directed:
|
932 |
only_eigvecs = kwargs.get("only_eigvecs", False)
|
@@ -1318,6 +1356,7 @@ def run_fn(
|
|
1318 |
return_eigvec_and_rgb=False,
|
1319 |
normalize_eigvec_return=False,
|
1320 |
separate_fg_bg=False,
|
|
|
1321 |
):
|
1322 |
# print(node_type2, head_index_text, make_symmetric)
|
1323 |
progress=gr.Progress()
|
@@ -1463,6 +1502,7 @@ def run_fn(
|
|
1463 |
"return_eigvec_and_rgb": return_eigvec_and_rgb,
|
1464 |
"normalize_eigvec_return": normalize_eigvec_return,
|
1465 |
"separate_fg_bg": separate_fg_bg,
|
|
|
1466 |
}
|
1467 |
# print(kwargs)
|
1468 |
|
@@ -4348,6 +4388,44 @@ with demo:
|
|
4348 |
outputs=[output_gallery, logging_text],
|
4349 |
)
|
4350 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4351 |
|
4352 |
with gr.Tab('Sub-cluster (dev)', visible=False) as sub_cluster_tab:
|
4353 |
with gr.Row():
|
|
|
927 |
|
928 |
return to_pil_images(rgb_all), logging_str
|
929 |
|
930 |
+
kway = kwargs.get("kway", False)
|
931 |
+
if kway:
|
932 |
+
only_eigvecs = True
|
933 |
+
rgb, _logging_str, eigvecs = compute_ncut(
|
934 |
+
features,
|
935 |
+
num_eig=num_eig,
|
936 |
+
num_sample_ncut=num_sample_ncut,
|
937 |
+
affinity_focal_gamma=affinity_focal_gamma,
|
938 |
+
knn_ncut=knn_ncut,
|
939 |
+
knn_tsne=knn_tsne,
|
940 |
+
num_sample_tsne=num_sample_tsne,
|
941 |
+
embedding_method=embedding_method,
|
942 |
+
embedding_metric=embedding_metric,
|
943 |
+
perplexity=perplexity,
|
944 |
+
n_neighbors=n_neighbors,
|
945 |
+
min_dist=min_dist,
|
946 |
+
sampling_method=sampling_method,
|
947 |
+
indirect_connection=indirect_connection,
|
948 |
+
make_orthogonal=make_orthogonal,
|
949 |
+
metric=ncut_metric,
|
950 |
+
only_eigvecs=only_eigvecs,
|
951 |
+
)
|
952 |
+
from ncut_pytorch import kway_ncut
|
953 |
+
kway_onehot = kway_ncut(eigvecs) # [N, K]
|
954 |
+
kway_indices = kway_onehot.argmax(dim=-1) # [N]
|
955 |
+
kway_indices = kway_indices.cpu().numpy()
|
956 |
+
if kway_indices.max() > 10:
|
957 |
+
cm = plt.colormaps['tab20']
|
958 |
+
rgb = cm(kway_indices / 20)
|
959 |
+
else:
|
960 |
+
cm = plt.colormaps['tab10']
|
961 |
+
rgb = cm(kway_indices / 10)
|
962 |
+
if kway_indices.max() > 20:
|
963 |
+
gr.Error("Too many clusters for kway_ncut")
|
964 |
+
rgb = rgb[:, :3]
|
965 |
+
rgb = rgb.reshape(*features.shape[:-1], 3)
|
966 |
+
return to_pil_images(rgb), logging_str
|
967 |
+
|
968 |
# ailgnedcut
|
969 |
if not directed:
|
970 |
only_eigvecs = kwargs.get("only_eigvecs", False)
|
|
|
1356 |
return_eigvec_and_rgb=False,
|
1357 |
normalize_eigvec_return=False,
|
1358 |
separate_fg_bg=False,
|
1359 |
+
kway=False,
|
1360 |
):
|
1361 |
# print(node_type2, head_index_text, make_symmetric)
|
1362 |
progress=gr.Progress()
|
|
|
1502 |
"return_eigvec_and_rgb": return_eigvec_and_rgb,
|
1503 |
"normalize_eigvec_return": normalize_eigvec_return,
|
1504 |
"separate_fg_bg": separate_fg_bg,
|
1505 |
+
"kway": kway,
|
1506 |
}
|
1507 |
# print(kwargs)
|
1508 |
|
|
|
4388 |
outputs=[output_gallery, logging_text],
|
4389 |
)
|
4390 |
|
4391 |
+
|
4392 |
+
with gr.Tab('K-way'):
|
4393 |
+
|
4394 |
+
with gr.Row():
|
4395 |
+
with gr.Column(scale=5, min_width=200):
|
4396 |
+
input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section()
|
4397 |
+
num_images_slider.value = 30
|
4398 |
+
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
|
4399 |
+
|
4400 |
+
with gr.Column(scale=5, min_width=200):
|
4401 |
+
output_gallery = make_output_images_section()
|
4402 |
+
# cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
4403 |
+
[
|
4404 |
+
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
4405 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
4406 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
4407 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
4408 |
+
sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
|
4409 |
+
] = make_parameters_section()
|
4410 |
+
num_eig_slider.value = 6
|
4411 |
+
num_eig_slider.maximum = 20
|
4412 |
+
|
4413 |
+
false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
|
4414 |
+
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
4415 |
+
|
4416 |
+
submit_button.click(
|
4417 |
+
partial(run_fn, n_ret=1, plot_clusters=False, kway=True),
|
4418 |
+
inputs=[
|
4419 |
+
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
4420 |
+
positive_prompt, negative_prompt,
|
4421 |
+
false_placeholder, no_prompt, no_prompt, no_prompt,
|
4422 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
4423 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
4424 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown,
|
4425 |
+
*[false_placeholder]*12,
|
4426 |
+
],
|
4427 |
+
outputs=[output_gallery, logging_text],
|
4428 |
+
)
|
4429 |
|
4430 |
with gr.Tab('Sub-cluster (dev)', visible=False) as sub_cluster_tab:
|
4431 |
with gr.Row():
|