Spaces:
Runtime error
Runtime error
yjernite
commited on
Commit
·
7044e35
1
Parent(s):
72bd979
dedup images and add captions
Browse files
app.py
CHANGED
@@ -15,13 +15,6 @@ professions_dset = load_from_disk("professions")
|
|
15 |
professions_df = professions_dset.to_pandas()
|
16 |
|
17 |
|
18 |
-
def get_image(model, fname):
|
19 |
-
return professions_dset.select(
|
20 |
-
professions_df[
|
21 |
-
(professions_df["image_path"] == fname) & (professions_df["model"] == model)
|
22 |
-
].index
|
23 |
-
)["image"][0]
|
24 |
-
|
25 |
|
26 |
clusters_dicts = dict(
|
27 |
(num_cl, json.load(open(f"clusters/professions_to_clusters_{num_cl}.json")))
|
@@ -158,13 +151,28 @@ def make_profession_table(num_clusters, prof_names, mod_name, max_cols=8):
|
|
158 |
.to_html()
|
159 |
)
|
160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
|
162 |
-
def show_examplars(num_clusters, prof_name, cl_id):
|
|
|
163 |
examplars_dict = clusters_dicts[num_clusters]["All"][prof_name][
|
164 |
"cluster_examplars"
|
165 |
][str(cl_id)]
|
166 |
-
l =
|
167 |
-
|
|
|
|
|
|
|
|
|
168 |
|
169 |
|
170 |
with gr.Blocks(title=TITLE) as demo:
|
@@ -261,7 +269,7 @@ with gr.Blocks(title=TITLE) as demo:
|
|
261 |
with gr.Row():
|
262 |
examplars_plot = gr.Gallery(
|
263 |
label="Profession images assigned to the selected cluster."
|
264 |
-
).style(grid=
|
265 |
demo.load(
|
266 |
show_examplars,
|
267 |
[
|
@@ -269,10 +277,10 @@ with gr.Blocks(title=TITLE) as demo:
|
|
269 |
profession_choice_focus,
|
270 |
cluster_id_focus,
|
271 |
],
|
272 |
-
examplars_plot,
|
273 |
queue=False,
|
274 |
)
|
275 |
-
for var in [cluster_id_focus]:
|
276 |
var.change(
|
277 |
show_examplars,
|
278 |
[
|
@@ -280,7 +288,7 @@ with gr.Blocks(title=TITLE) as demo:
|
|
280 |
profession_choice_focus,
|
281 |
cluster_id_focus,
|
282 |
],
|
283 |
-
examplars_plot,
|
284 |
queue=False,
|
285 |
)
|
286 |
|
|
|
15 |
professions_df = professions_dset.to_pandas()
|
16 |
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
clusters_dicts = dict(
|
20 |
(num_cl, json.load(open(f"clusters/professions_to_clusters_{num_cl}.json")))
|
|
|
151 |
.to_html()
|
152 |
)
|
153 |
|
154 |
+
def get_image(model, fname, score):
|
155 |
+
return (
|
156 |
+
professions_dset.select(
|
157 |
+
professions_df[
|
158 |
+
(professions_df["image_path"] == fname) & (professions_df["model"] == model)
|
159 |
+
].index
|
160 |
+
)["image"][0],
|
161 |
+
" ".join(fname.split("/")[0].split("_")[4:]) + f" | {score:.2f}" + f" | {models[model]}"
|
162 |
+
)
|
163 |
+
|
164 |
|
165 |
+
def show_examplars(num_clusters, prof_name, cl_id, confidence_threshold=0.5):
|
166 |
+
# only show images where the similarity to the centroid is > 0.7
|
167 |
examplars_dict = clusters_dicts[num_clusters]["All"][prof_name][
|
168 |
"cluster_examplars"
|
169 |
][str(cl_id)]
|
170 |
+
l = [tuple(img) for img in examplars_dict["close"] + examplars_dict["mid"][:2] + examplars_dict["far"]]
|
171 |
+
l = [img for i, img in enumerate(l) if img[0] > confidence_threshold and img not in l[:i]]
|
172 |
+
return (
|
173 |
+
[get_image(model, fname, score) for score, model, fname in l],
|
174 |
+
gr.update(label=f"Generations for profession ''{prof_name}'' assigned to cluster {cl_id} of {num_clusters}")
|
175 |
+
)
|
176 |
|
177 |
|
178 |
with gr.Blocks(title=TITLE) as demo:
|
|
|
269 |
with gr.Row():
|
270 |
examplars_plot = gr.Gallery(
|
271 |
label="Profession images assigned to the selected cluster."
|
272 |
+
).style(grid=4, height="auto", container=True)
|
273 |
demo.load(
|
274 |
show_examplars,
|
275 |
[
|
|
|
277 |
profession_choice_focus,
|
278 |
cluster_id_focus,
|
279 |
],
|
280 |
+
[examplars_plot, examplars_plot],
|
281 |
queue=False,
|
282 |
)
|
283 |
+
for var in [num_clusters_focus, profession_choice_focus, cluster_id_focus]:
|
284 |
var.change(
|
285 |
show_examplars,
|
286 |
[
|
|
|
288 |
profession_choice_focus,
|
289 |
cluster_id_focus,
|
290 |
],
|
291 |
+
[examplars_plot, examplars_plot],
|
292 |
queue=False,
|
293 |
)
|
294 |
|