"""Gradio demo that visualizes ViT explanations for an uploaded image."""

import io

import gradio as gr
from PIL import Image
from torchvision import transforms

from visualization import generate_visualization

# Map per-channel values from [0, 1] (ToTensor output) to [-1, 1].
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

# Preprocessing: resize shorter side to 256, center-crop to 224x224,
# convert to a float tensor, then normalize.
TRANSFORM = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)


def generate_viz(image, class_index=None, use_threshold=False):
    """Render an explanation visualization for ``image``.

    Args:
        image: PIL image from the Gradio ``Image`` input; ``None`` when the
            user submits without uploading an image.
        class_index: Class index to explain, or ``None``. Converted with
            ``int()`` before being forwarded to ``generate_visualization``.
        use_threshold: Flag forwarded unchanged to ``generate_visualization``.

    Returns:
        The rendered visualization as an RGB PIL image.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if class_index is not None:
        class_index = int(class_index)
    print(f"Image: {image.size}")
    print(f"Class: {class_index}")
    print(f"use_threshold: {use_threshold}")

    # Normalize above has exactly 3 channel statistics, so make sure an
    # RGBA / grayscale image is converted to 3 channels first.
    image_trans = TRANSFORM(image.convert("RGB"))
    viz = generate_visualization(
        image_trans, class_index=class_index, use_threshold=use_threshold
    )

    # Render into an in-memory buffer rather than a fixed file on disk so
    # concurrent requests cannot overwrite each other's output.
    # NOTE(review): assumes ``viz`` is a matplotlib Figure (it exposes
    # ``savefig``). If it is created via pyplot it is never closed here,
    # which may accumulate open figures — confirm in ``visualization``.
    buffer = io.BytesIO()
    viz.savefig(buffer, format="png")
    buffer.seek(0)
    with Image.open(buffer) as rendered:
        return rendered.convert("RGB")


title = "Explain ViT 😊"

iface = gr.Interface(
    fn=generate_viz,
    # Gradio passes component values positionally, so this order must match
    # generate_viz's parameters: (image, class_index, use_threshold).
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Number(label="Class Index", info="Class index to explain"),
        gr.Checkbox(label="use_threshold", value=False),
    ],
    outputs=gr.Image(),
    title=title,
    allow_flagging="never",
    cache_examples=True,
    # One value per input component, in the same order as `inputs`.
    examples=[
        ["ViT_DeiT/samples/catdog.png", None, False],
        ["ViT_DeiT/samples/catdog.png", 243, False],
        ["ViT_DeiT/samples/el2.png", None, False],
        ["ViT_DeiT/samples/el2.png", 340, False],
        ["ViT_DeiT/samples/dogbird.png", 161, False],
    ],
)

iface.launch(debug=True)