yjernite commited on
Commit
7044e35
·
1 Parent(s): 72bd979

dedup images and add captions

Browse files
Files changed (1) hide show
  1. app.py +22 -14
app.py CHANGED
@@ -15,13 +15,6 @@ professions_dset = load_from_disk("professions")
15
  professions_df = professions_dset.to_pandas()
16
 
17
 
18
- def get_image(model, fname):
19
- return professions_dset.select(
20
- professions_df[
21
- (professions_df["image_path"] == fname) & (professions_df["model"] == model)
22
- ].index
23
- )["image"][0]
24
-
25
 
26
  clusters_dicts = dict(
27
  (num_cl, json.load(open(f"clusters/professions_to_clusters_{num_cl}.json")))
@@ -158,13 +151,28 @@ def make_profession_table(num_clusters, prof_names, mod_name, max_cols=8):
158
  .to_html()
159
  )
160
 
 
 
 
 
 
 
 
 
 
 
161
 
162
- def show_examplars(num_clusters, prof_name, cl_id):
 
163
  examplars_dict = clusters_dicts[num_clusters]["All"][prof_name][
164
  "cluster_examplars"
165
  ][str(cl_id)]
166
- l = list(chain(*[examplars_dict[k] for k in examplars_dict]))
167
- return [get_image(model, fname) for _, model, fname in l]
 
 
 
 
168
 
169
 
170
  with gr.Blocks(title=TITLE) as demo:
@@ -261,7 +269,7 @@ with gr.Blocks(title=TITLE) as demo:
261
  with gr.Row():
262
  examplars_plot = gr.Gallery(
263
  label="Profession images assigned to the selected cluster."
264
- ).style(grid=5, height="auto")
265
  demo.load(
266
  show_examplars,
267
  [
@@ -269,10 +277,10 @@ with gr.Blocks(title=TITLE) as demo:
269
  profession_choice_focus,
270
  cluster_id_focus,
271
  ],
272
- examplars_plot,
273
  queue=False,
274
  )
275
- for var in [cluster_id_focus]:
276
  var.change(
277
  show_examplars,
278
  [
@@ -280,7 +288,7 @@ with gr.Blocks(title=TITLE) as demo:
280
  profession_choice_focus,
281
  cluster_id_focus,
282
  ],
283
- examplars_plot,
284
  queue=False,
285
  )
286
 
 
15
  professions_df = professions_dset.to_pandas()
16
 
17
 
 
 
 
 
 
 
 
18
 
19
  clusters_dicts = dict(
20
  (num_cl, json.load(open(f"clusters/professions_to_clusters_{num_cl}.json")))
 
151
  .to_html()
152
  )
153
 
154
+ def get_image(model, fname, score):
155
+ return (
156
+ professions_dset.select(
157
+ professions_df[
158
+ (professions_df["image_path"] == fname) & (professions_df["model"] == model)
159
+ ].index
160
+ )["image"][0],
161
+ " ".join(fname.split("/")[0].split("_")[4:]) + f" | {score:.2f}" + f" | {models[model]}"
162
+ )
163
+
164
 
165
+ def show_examplars(num_clusters, prof_name, cl_id, confidence_threshold=0.5):
166
+ # only show images where the similarity to the centroid is > 0.7
167
  examplars_dict = clusters_dicts[num_clusters]["All"][prof_name][
168
  "cluster_examplars"
169
  ][str(cl_id)]
170
+ l = [tuple(img) for img in examplars_dict["close"] + examplars_dict["mid"][:2] + examplars_dict["far"]]
171
+ l = [img for i, img in enumerate(l) if img[0] > confidence_threshold and img not in l[:i]]
172
+ return (
173
+ [get_image(model, fname, score) for score, model, fname in l],
174
+ gr.update(label=f"Generations for profession ''{prof_name}'' assigned to cluster {cl_id} of {num_clusters}")
175
+ )
176
 
177
 
178
  with gr.Blocks(title=TITLE) as demo:
 
269
  with gr.Row():
270
  examplars_plot = gr.Gallery(
271
  label="Profession images assigned to the selected cluster."
272
+ ).style(grid=4, height="auto", container=True)
273
  demo.load(
274
  show_examplars,
275
  [
 
277
  profession_choice_focus,
278
  cluster_id_focus,
279
  ],
280
+ [examplars_plot, examplars_plot],
281
  queue=False,
282
  )
283
+ for var in [num_clusters_focus, profession_choice_focus, cluster_id_focus]:
284
  var.change(
285
  show_examplars,
286
  [
 
288
  profession_choice_focus,
289
  cluster_id_focus,
290
  ],
291
+ [examplars_plot, examplars_plot],
292
  queue=False,
293
  )
294