WwYc commited on
Commit
8038161
·
verified ·
1 Parent(s): c32bf50

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -9
app.py CHANGED
@@ -14,16 +14,14 @@ TRANSFORM = transforms.Compose(
14
  ]
15
  )
16
 
17
- use_threshold = [True, False]
18
 
19
- def generate_viz(image, class_index=None, use_threshold=False):
20
  if class_index is not None:
21
  class_index = int(class_index)
22
  print(f"Image: {image.size}")
23
  print(f"Class: {class_index}")
24
- print(f"use_threshold: {use_threshold}")
25
 
26
- viz = do_explain(TRANSFORM, image, class_index=class_index, use_threshold=use_threshold)
27
  viz.savefig("visualization.png")
28
  return Image.open("visualization.png").convert("RGB")
29
 
@@ -31,10 +29,7 @@ title = "Explain ViT 😊"
31
 
32
  iface = gr.Interface(fn=generate_viz, inputs=[
33
  gr.Image(type="pil", label="Input Image"),
34
- gr.Dropdown(
35
- list(use_threshold),
36
- label="use_threshold",
37
- ),
38
  gr.Number(label="Class Index", info="Class index to explain"),
39
  ],
40
  outputs=gr.Image(),
@@ -46,7 +41,6 @@ iface = gr.Interface(fn=generate_viz, inputs=[
46
  ["ViT_DeiT/samples/catdog.png", 243],
47
  ["ViT_DeiT/samples/el2.png", None],
48
  ["ViT_DeiT/samples/el2.png", 340],
49
- ["ViT_DeiT/samples/dogbird.png", 161],
50
  ],
51
  )
52
 
 
14
  ]
15
  )
16
 
 
17
 
18
+ def generate_viz(image, class_index=None):
19
  if class_index is not None:
20
  class_index = int(class_index)
21
  print(f"Image: {image.size}")
22
  print(f"Class: {class_index}")
 
23
 
24
+ viz = do_explain(TRANSFORM, image, class_index=class_index)
25
  viz.savefig("visualization.png")
26
  return Image.open("visualization.png").convert("RGB")
27
 
 
29
 
30
  iface = gr.Interface(fn=generate_viz, inputs=[
31
  gr.Image(type="pil", label="Input Image"),
32
+
 
 
 
33
  gr.Number(label="Class Index", info="Class index to explain"),
34
  ],
35
  outputs=gr.Image(),
 
41
  ["ViT_DeiT/samples/catdog.png", 243],
42
  ["ViT_DeiT/samples/el2.png", None],
43
  ["ViT_DeiT/samples/el2.png", 340],
 
44
  ],
45
  )
46