# Hugging Face Space status: Sleeping
import gradio as gr | |
from PIL import Image | |
from torchvision import transforms | |
from visualization import generate_visualization | |
# Per-channel normalization (x - 0.5) / 0.5 on all three channels, which
# shifts ToTensor's [0, 1] pixel range to [-1, 1].
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Preprocessing applied to every input image before it is explained:
# resize (shorter side to 256 px), centre-crop to 224x224, convert to a
# tensor, then apply the normalization above.
TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

# Choices offered by the "use_threshold" dropdown in the Gradio UI.
use_threshold = [True, False]
def generate_viz(image, class_index=None, use_threshold=False):
    """Render an explanation visualization for ``image`` and return it as RGB.

    Args:
        image: Input PIL image (as delivered by ``gr.Image(type="pil")``).
        class_index: Class to explain. ``None`` is forwarded unchanged to
            ``generate_visualization``; any other value (e.g. the float a
            ``gr.Number`` yields) is converted with ``int()``.
        use_threshold: Forwarded to ``generate_visualization``. The strings
            ``"true"`` / ``"false"`` (case-insensitive) are mapped to booleans
            so a UI that hands back text cannot turn ``"False"`` into a truthy
            value.

    Returns:
        The figure returned by ``generate_visualization``, rasterised to PNG
        and converted to an RGB ``PIL.Image.Image``.

    Raises:
        gr.Error: if no image was supplied.
    """
    import io  # local import keeps the file-level import header unchanged

    if image is None:
        raise gr.Error("Please provide an input image.")
    if class_index is not None:
        class_index = int(class_index)
    if isinstance(use_threshold, str):
        use_threshold = use_threshold.strip().lower() == "true"

    print(f"Image: {image.size}")
    print(f"Class: {class_index}")
    print(f"use_threshold: {use_threshold}")

    # TRANSFORM's Normalize carries 3-channel statistics; RGBA, palette or
    # greyscale images would otherwise yield a tensor with a different
    # channel count after ToTensor.
    image_trans = TRANSFORM(image.convert("RGB"))
    viz = generate_visualization(
        image_trans, class_index=class_index, use_threshold=use_threshold
    )

    # Render into memory instead of a fixed "visualization.png" on disk:
    # concurrent requests would otherwise overwrite each other's file.
    # NOTE(review): viz is presumably a matplotlib Figure (it exposes
    # savefig); if so it is never closed here -- consider plt.close(viz).
    buffer = io.BytesIO()
    viz.savefig(buffer, format="png")
    buffer.seek(0)
    with Image.open(buffer) as rendered:
        return rendered.convert("RGB")
title = "Explain ViT π"

# Gradio passes input values to ``fn`` positionally, in the order of
# ``inputs``, so the components must follow generate_viz's signature:
# (image, class_index, use_threshold).
iface = gr.Interface(
    fn=generate_viz,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Number(label="Class Index", info="Class index to explain"),
        gr.Dropdown(
            list(use_threshold),
            label="use_threshold",
        ),
    ],
    outputs=gr.Image(),
    title=title,
    allow_flagging="never",
    cache_examples=True,
    # One value per input component: [image, class_index, use_threshold].
    # A class_index of None is forwarded to generate_viz unchanged.
    examples=[
        ["ViT_DeiT/samples/catdog.png", None, False],
        ["ViT_DeiT/samples/catdog.png", 243, False],
        ["ViT_DeiT/samples/el2.png", None, False],
        ["ViT_DeiT/samples/el2.png", 340, False],
        ["ViT_DeiT/samples/dogbird.png", 161, False],
    ],
)
iface.launch(debug=True)