# explain-ViT / app.py
# Last commit by WwYc: cb60283 (verified) "Update app.py"
# File size: 1.55 kB
import io

import gradio as gr
from PIL import Image
from torchvision import transforms

from visualization import generate_visualization
# Per-channel normalisation: (x - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1].
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Preprocessing applied to every input image before it is handed to
# generate_visualization.
TRANSFORM = transforms.Compose([
    transforms.Resize(256),      # int size: shorter edge resized to 256 px
    transforms.CenterCrop(224),  # central 224x224 patch
    transforms.ToTensor(),       # PIL image -> float tensor
    normalize,
])

# Choices offered by the "use_threshold" dropdown in the Gradio UI.
# NOTE: generate_viz has a parameter with the same name that shadows this list.
use_threshold = [True, False]
def generate_viz(image, class_index=None, use_threshold=False):
    """Explain ``image`` with ``generate_visualization`` and return the plot as an image.

    Args:
        image: PIL image to explain (what ``gr.Image(type="pil")`` delivers).
        class_index: Optional class index to explain. ``gr.Number`` may deliver
            a float, so any non-None value is converted with ``int()``; ``None``
            is forwarded to ``generate_visualization`` unchanged.
        use_threshold: Forwarded unchanged to ``generate_visualization``.

    Returns:
        The figure produced by ``generate_visualization``, rendered to PNG in
        memory and decoded as an RGB ``PIL.Image.Image``.

    Raises:
        gr.Error: If no image was supplied (shown to the user by Gradio).
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if class_index is not None:
        class_index = int(class_index)
    print(f"Image: {image.size}")
    print(f"Class: {class_index}")
    print(f"use_threshold: {use_threshold}")
    # TRANSFORM's Normalize has exactly 3 per-channel statistics, so RGBA or
    # grayscale input would fail there; coerce to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image_trans = TRANSFORM(image)
    viz = generate_visualization(image_trans, class_index=class_index, use_threshold=use_threshold)
    # Render into memory instead of a fixed file in the working directory, so
    # concurrent requests cannot overwrite each other's output.
    # NOTE(review): the figure is never closed here; if generate_visualization
    # creates it through pyplot it stays registered there - confirm and close it.
    buffer = io.BytesIO()
    viz.savefig(buffer, format="png")
    buffer.seek(0)
    with Image.open(buffer) as rendered:
        return rendered.convert("RGB")
title = "Explain ViT 😊"

# gr.Interface calls ``fn`` positionally with one value per input component,
# in the order listed below, so this order must match
# generate_viz(image, class_index, use_threshold).
iface = gr.Interface(
    fn=generate_viz,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Number(label="Class Index", info="Class index to explain"),
        gr.Dropdown(
            list(use_threshold),
            value=False,  # same default as generate_viz's use_threshold
            label="use_threshold",
        ),
    ],
    outputs=gr.Image(),
    title=title,
    allow_flagging="never",
    cache_examples=True,
    # One value per input component: [image path, class index, use_threshold].
    # A None class index is forwarded to generate_visualization as-is.
    examples=[
        ["ViT_DeiT/samples/catdog.png", None, False],
        ["ViT_DeiT/samples/catdog.png", 243, False],
        ["ViT_DeiT/samples/el2.png", None, False],
        ["ViT_DeiT/samples/el2.png", 340, False],
        ["ViT_DeiT/samples/dogbird.png", 161, False],
    ],
)
iface.launch(debug=True)