WwYc commited on
Commit
fd079a7
·
verified ·
1 Parent(s): 5196e4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -4
app.py CHANGED
@@ -1,7 +1,52 @@
1
  import gradio as gr
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from PIL import Image
3
+ from torchvision import transforms
4
+ from visualization import generate_visualization
5
 
6
+ normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
7
+ TRANSFORM = transforms.Compose(
8
+ [
9
+ transforms.Resize(256),
10
+ transforms.CenterCrop(224),
11
+ transforms.ToTensor(),
12
+ normalize,
13
+ ]
14
+ )
15
 
16
+ use_threshold = [True, False]
17
+
18
+ def generate_viz(image, class_index=None, use_threshold=False):
19
+ if class_index is not None:
20
+ class_index = int(class_index)
21
+ print(f"Image: {image.size}")
22
+ print(f"Class: {class_index}")
23
+ print(f"use_threshold: {use_threshold}")
24
+ image_trans = TRANSFORM(image)
25
+ viz = generate_visualization(image_trans, class_index=class_index, use_threshold=use_threshold)
26
+ viz.savefig("visualization.png")
27
+ return Image.open("visualization.png").convert("RGB")
28
+
29
+ title = "Explain ViT 😊"
30
+
31
+ iface = gr.Interface(fn=generate_viz, inputs=[
32
+ gr.Image(type="pil", label="Input Image"),
33
+ gr.Dropdown(
34
+ list(use_threshold),
35
+ label="use_threshold",
36
+ ),
37
+ gr.Number(label="Class Index", info="Class index to explain"),
38
+ ],
39
+ outputs=gr.Image(),
40
+ title=title,
41
+ allow_flagging="never",
42
+ cache_examples=True,
43
+ examples=[
44
+ ["ViT_DeiT/samples/catdog.png",None],
45
+ ["ViT_DeiT/samples/catdog.png", 243],
46
+ ["ViT_DeiT/samples/el2.png", None],
47
+ ["ViT_DeiT/samples/el2.png", 340],
48
+ ["ViT_DeiT/samples/dogbird.png", 161],
49
+ ],
50
+ )
51
+
52
+ iface.launch(debug=True)