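# Gradio Space that visualizes explanations for a ViT/DeiT image classifier.
# The model-specific logic lives in explain.do_explain; this file only handles
# preprocessing and the web UI.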
import gradio as gr
from PIL import Image
from torchvision import transforms
from explain import do_explain
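
# Standard DeiT preprocessing: resize the short side to 256, center-crop to
# 224x224, convert to a tensor, and map each channel from [0, 1] to [-1, 1].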
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
TRANSFORM = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)

def generate_viz(image, class_index=None):
    if image is None:
        # Gradio passes None when no image is supplied; fail with a readable message.
        raise gr.Error("Please provide an input image.")
    if class_index is not None:
        # The Number component returns a float; the explainer expects an int.
        # None is passed through to do_explain unchanged.
        class_index = int(class_index)
    print(f"Image: {image.size}")
    print(f"Class: {class_index}")
    # do_explain returns a plottable visualization (saved to disk below) and the prediction.
    viz, pred = do_explain(TRANSFORM, image, class_index=class_index)
    viz.savefig("visualization.png")
    return Image.open("visualization.png").convert("RGB"), pred
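
# Hypothetical local smoke test (not part of the Space; assumes the bundled
# sample image and the explainer's weights are available):
#   img = Image.open("ViT_DeiT/samples/catdog.png").convert("RGB")
#   viz_img, pred = generate_viz(img, class_index=243)
#   viz_img.save("smoke_test.png")
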
title = "Explain ViT π"
iface = gr.Interface(
    fn=generate_viz,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Number(label="Class Index", info="Class index to explain"),
    ],
    outputs=[
        gr.Image(label="XAI-Image"),
        gr.Text(label="prob"),
    ],
    title=title,
    allow_flagging="never",
    cache_examples=True,
    examples=[
        ["ViT_DeiT/samples/catdog.png", None],
        ["ViT_DeiT/samples/catdog.png", 243],
        ["ViT_DeiT/samples/el2.png", None],
        ["ViT_DeiT/samples/el2.png", 340],
    ],
)
iface.launch(debug=True)