from fastai.vision.all import *
import gradio as gr
from captum.attr import Saliency
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
import tempfile
from matplotlib.figure import Figure

learn = load_learner('animal_model.pkl')

categories = learn.dls.vocab


def generate_saliency(image):
    """Classify *image* and render a saliency heatmap for the predicted class.

    Args:
        image: the uploaded image as delivered by ``gr.Image(type="pil")``.

    Returns:
        A ``(predictions, heatmap_path)`` pair: a dict mapping each category in
        ``learn.dls.vocab`` to its predicted probability, and the path of a
        freshly written PNG file containing the saliency heatmap.
    """
    img = PILImage.create(image)

    # pred_idx is the predicted class index; probs holds one probability
    # per category, in vocab order.
    _, pred_idx, probs = learn.predict(img)

    # Build the saliency input through the learner's own inference pipeline
    # (the same test_dl that learn.predict uses). This way the input gets the
    # same resizing/normalization and lives on the same device as the input
    # that produced the prediction above.
    xb = learn.dls.test_dl([img], num_workers=0).one_batch()[0]
    # Hand Captum a plain torch.Tensor rather than fastai's TensorImage subclass.
    xb = xb.as_subclass(torch.Tensor).detach().requires_grad_()

    learn.model.eval()
    saliency = Saliency(learn.model)
    saliency_map = saliency.attribute(xb, target=int(pred_idx))

    # Collapse the (1, C, H, W) attributions into an (H, W) map by summing
    # absolute gradients over the channel axis.
    saliency_np = saliency_map[0].abs().sum(dim=0).detach().cpu().numpy()

    # Object-oriented Figure API instead of pyplot's global state, because
    # Gradio may run this handler concurrently in worker threads.
    fig = Figure(figsize=(10, 10))
    ax = fig.add_subplot()
    ax.imshow(saliency_np, cmap='viridis')
    ax.axis('off')
    fig.tight_layout()

    # One unique file per request so concurrent users don't overwrite
    # each other's heatmap.
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
        fig.savefig(tmp, format='png', pad_inches=0)

    # Exactly one value per declared output component (Label, Image).
    return dict(zip(categories, map(float, probs))), tmp.name


# Gradio interface
image = gr.Image(type="pil")
examples = ['polar_bear_real.jpg', 'polar_bear.jpg']

interface = gr.Interface(
    fn=generate_saliency,
    inputs=image,
    outputs=[
        gr.Label(label="Predictions"),
        gr.Image(type="filepath", label="Saliency Heatmap"),
    ],
    examples=examples,
)

if __name__ == "__main__":
    interface.launch()