# Hugging Face Spaces — status: Sleeping
from fastai.vision.all import * | |
import gradio as gr | |
from captum.attr import Saliency | |
from torchvision import transforms | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
# Load the exported fastai Learner once at import time; every request reuses it.
learn = load_learner('animal_model.pkl')

# Class labels from the learner's vocabulary.
categories = learn.dls.vocab

# Fixed-size preprocessing: resize to 128x128, then convert the PIL image to a tensor.
transform = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])
def generate_saliency(image):
    """Classify *image* and render a gradient-saliency heatmap for the predicted class.

    Args:
        image: The uploaded image; the Gradio input delivers a PIL image, and
            anything ``PILImage.create`` accepts also works.

    Returns:
        A 2-tuple ``(probabilities, heatmap_path)``, one value per Gradio output:
        - ``probabilities``: dict mapping each label in ``categories`` to its
          predicted probability as a Python float.
        - ``heatmap_path``: path of the PNG heatmap written by this call
          (``'saliency_heatmap.png'`` in the working directory).
    """
    img = PILImage.create(image)

    # `pred` is the predicted class index; `probs` holds one probability per class.
    _, pred, probs = learn.predict(img)

    # Build the model input with the learner's own inference pipeline (the same
    # `test_dl` path `learn.predict` uses), so the saliency is computed on exactly
    # what the model scored: its item transforms plus any batch transforms such
    # as normalization. A hand-rolled Resize/ToTensor would skip those.
    dl = learn.dls.test_dl([img], num_workers=0)
    tensor_image = dl.one_batch()[0]

    # Gradients should reflect inference behaviour (no dropout, running BN stats).
    learn.model.eval()
    saliency_map = Saliency(learn.model).attribute(tensor_image, target=int(pred))

    # Drop the batch axis, then collapse the channel axis (assumed first, as in
    # fastai image batches) into one 2-D map of summed absolute gradients.
    saliency_np = np.abs(saliency_map.squeeze(0).detach().cpu().numpy()).sum(axis=0)

    heatmap_path = 'saliency_heatmap.png'
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(saliency_np, cmap='viridis')
    ax.axis('off')
    # pad_inches only takes effect together with bbox_inches='tight'.
    fig.savefig(heatmap_path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)

    # Exactly one value per Gradio output component (predictions, heatmap).
    return dict(zip(categories, map(float, probs))), heatmap_path
# Gradio interface: upload an image, get class probabilities plus a saliency heatmap.
image = gr.Image(type="pil")  # handed to generate_saliency as a PIL image
label = gr.Label(label="Predictions")  # displays the {class: probability} dict
heatmap = gr.Image(type="filepath", label="Saliency Heatmap")  # displays the returned PNG path
# NOTE(review): presumably bundled next to this script — verify they exist at startup.
examples = ['polar_bear_real.jpg', 'polar_bear.jpg']

interface = gr.Interface(
    fn=generate_saliency,
    inputs=image,
    # Gradio matches output components positionally to the values fn returns.
    outputs=[label, heatmap],
    examples=examples,
)
interface.launch()